未验证 提交 af127342 编写于 作者: 张经纬 提交者: GitHub

[CodeStyle][CINN] format cpp code via clang-format (#54961)

* fix clang-format

* 'fix_clang-format'

* fix remaining errors

* format

* empty commit, re-trigger all ci

* empty commit, re-trigger all ci

---------
Co-authored-by: NSigureMo <sigure.qaq@gmail.com>
上级 a7419ff5
...@@ -47,7 +47,8 @@ repos: ...@@ -47,7 +47,8 @@ repos:
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
exclude: | exclude: |
(?x)^( (?x)^(
paddle/utils/.* paddle/utils/.*|
paddle/cinn/utils/registry.h
)$ )$
# For Python files # For Python files
- repo: https://github.com/psf/black.git - repo: https://github.com/psf/black.git
......
...@@ -41,7 +41,7 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) { ...@@ -41,7 +41,7 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
for (const ir::Expr& e : indices) { for (const ir::Expr& e : indices) {
// Whether we have to convert other types, like const numbers to Var? // Whether we have to convert other types, like const numbers to Var?
if (e.As<ir::_Var_>() != nullptr) { if (e.As<ir::_Var_>() != nullptr) {
ir::Expr copy_e = optim::IRCopy(e); ir::Expr copy_e = optim::IRCopy(e);
ir::_Var_* var_ref = copy_e.As<ir::_Var_>(); ir::_Var_* var_ref = copy_e.As<ir::_Var_>();
result.emplace_back(ir::Var(var_ref)); result.emplace_back(ir::Var(var_ref));
} }
...@@ -58,26 +58,32 @@ void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) { ...@@ -58,26 +58,32 @@ void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
const ir::Load* load_expr = x->As<ir::Load>(); const ir::Load* load_expr = x->As<ir::Load>();
if (load_expr != nullptr) { if (load_expr != nullptr) {
const ir::Tensor t = load_expr->tensor.as_tensor_ref(); const ir::Tensor t = load_expr->tensor.as_tensor_ref();
sche_block->read_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices))); sche_block->read_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
return false; return false;
} }
const ir::Store* store_expr = x->As<ir::Store>(); const ir::Store* store_expr = x->As<ir::Store>();
if (store_expr != nullptr) { if (store_expr != nullptr) {
const ir::Tensor t = store_expr->tensor.as_tensor_ref(); const ir::Tensor t = store_expr->tensor.as_tensor_ref();
sche_block->write_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices))); sche_block->write_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
return false; return false;
} }
return false; return false;
}); });
} }
bool ContainsNodeType(ir::Expr expr, const std::unordered_set<ir::IrNodeTy>& node_types) { bool ContainsNodeType(ir::Expr expr,
std::set<ir::Expr> collection = ir::CollectIRNodesWithoutTensor( const std::unordered_set<ir::IrNodeTy>& node_types) {
expr, [&](const Expr* x) { return node_types.find(x->node_type()) != node_types.end(); }); std::set<ir::Expr> collection =
ir::CollectIRNodesWithoutTensor(expr, [&](const Expr* x) {
return node_types.find(x->node_type()) != node_types.end();
});
return !collection.empty(); return !collection.empty();
} }
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<ir::LoweredFunc>& lowered_funcs) { std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
const std::vector<ir::LoweredFunc>& lowered_funcs) {
std::unordered_set<std::string> result; std::unordered_set<std::string> result;
for (const ir::LoweredFunc& func : lowered_funcs) { for (const ir::LoweredFunc& func : lowered_funcs) {
for (const ir::Argument& arg : func->args) { for (const ir::Argument& arg : func->args) {
...@@ -90,18 +96,22 @@ std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector< ...@@ -90,18 +96,22 @@ std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<
} }
bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) { bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
const ir::ScheduleBlock* sche_block = sche_block_realize.schedule_block.As<ir::ScheduleBlock>(); const ir::ScheduleBlock* sche_block =
if (sche_block->write_buffers.size() != 1 || sche_block->read_buffers.empty()) { sche_block_realize.schedule_block.As<ir::ScheduleBlock>();
if (sche_block->write_buffers.size() != 1 ||
sche_block->read_buffers.empty()) {
return false; return false;
} }
const ir::Expr& write_buffer = sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer; const ir::Expr& write_buffer =
sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer;
// Enumerate each read region, get the number of schedule block iter vars // Enumerate each read region, get the number of schedule block iter vars
// which are not used to index the read region // which are not used to index the read region
int total_unused_iter_vars = 0; int total_unused_iter_vars = 0;
for (const ir::Expr& read_buffer_expr : sche_block->read_buffers) { for (const ir::Expr& read_buffer_expr : sche_block->read_buffers) {
const ir::_BufferRange_* read_buffer = read_buffer_expr.As<ir::_BufferRange_>(); const ir::_BufferRange_* read_buffer =
read_buffer_expr.As<ir::_BufferRange_>();
// Skip the reduction buffer // Skip the reduction buffer
if (read_buffer->buffer == write_buffer) { if (read_buffer->buffer == write_buffer) {
continue; continue;
...@@ -133,18 +143,22 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) { ...@@ -133,18 +143,22 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
return total_unused_iter_vars >= 1; return total_unused_iter_vars >= 1;
} }
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body) { ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body) {
ir::ModuleExpr mod_expr(std::vector<ir::Expr>({body})); ir::ModuleExpr mod_expr(std::vector<ir::Expr>({body}));
ir::IRSchedule ir_sch(mod_expr); ir::IRSchedule ir_sch(mod_expr);
// temp_bufs may be deleted during auto tuning (such as auto inline), // temp_bufs may be deleted during auto tuning (such as auto inline),
// we have to check from old temp bufs and set them as local buffer. // we have to check from old temp bufs and set them as local buffer.
for (const ir::Buffer& buf : old_func->temp_bufs) { for (const ir::Buffer& buf : old_func->temp_bufs) {
const std::string& buf_name = buf->name; const std::string& buf_name = buf->name;
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks(); std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
for (ir::Expr& e : all_block_realizes) { for (ir::Expr& e : all_block_realizes) {
const ir::ScheduleBlockRealize* sche_block_realize = e.As<ir::ScheduleBlockRealize>(); const ir::ScheduleBlockRealize* sche_block_realize =
const std::string& sche_name = sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name; e.As<ir::ScheduleBlockRealize>();
const std::string& sche_name =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name;
if (buf_name == "_" + sche_name) { if (buf_name == "_" + sche_name) {
VLOG(6) << "Set local buffer for temp buffer " << buf_name; VLOG(6) << "Set local buffer for temp buffer " << buf_name;
ir_sch.SetBuffer(e, "local", true); ir_sch.SetBuffer(e, "local", true);
...@@ -159,14 +173,17 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo ...@@ -159,14 +173,17 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo
#endif #endif
// Get new temp bufs by analyzing. // Get new temp bufs by analyzing.
std::vector<ir::Buffer> new_temp_bufs = lang::GetTempBuffers(old_func->args, updated_body); std::vector<ir::Buffer> new_temp_bufs =
ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(old_func->name, old_func->args, updated_body, new_temp_bufs); lang::GetTempBuffers(old_func->args, updated_body);
ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(
old_func->name, old_func->args, updated_body, new_temp_bufs);
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
if (target == common::DefaultNVGPUTarget()) { if (target == common::DefaultNVGPUTarget()) {
new_func->PrepareCudaAxisInfoFromBody(); new_func->PrepareCudaAxisInfoFromBody();
} }
#endif #endif
new_func = optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref(); new_func =
optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref();
new_func->PrepareBufferCastExprs(/*with_expr_gen_tensor = */ false); new_func->PrepareBufferCastExprs(/*with_expr_gen_tensor = */ false);
return new_func; return new_func;
......
...@@ -27,12 +27,14 @@ namespace auto_schedule { ...@@ -27,12 +27,14 @@ namespace auto_schedule {
void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block); void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block);
bool ContainsNodeType(ir::Expr expr, const std::unordered_set<ir::IrNodeTy>& node_types); bool ContainsNodeType(ir::Expr expr,
const std::unordered_set<ir::IrNodeTy>& node_types);
/** /**
* Collects all input lowered_funcs and return names of all output arguments * Collects all input lowered_funcs and return names of all output arguments
*/ */
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<ir::LoweredFunc>& lowered_funcs); std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
const std::vector<ir::LoweredFunc>& lowered_funcs);
/** /**
* Determine whether a schedule block needs multileveltiling * Determine whether a schedule block needs multileveltiling
...@@ -42,7 +44,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize); ...@@ -42,7 +44,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize);
/** /**
* Update a LoweredFunc by regenerating related fields with a new function body * Update a LoweredFunc by regenerating related fields with a new function body
*/ */
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body); ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body);
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -49,8 +49,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) { ...@@ -49,8 +49,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
ir::Tensor B = lang::Compute( ir::Tensor B = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B}); poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ASSERT_FALSE(funcs.empty()); ASSERT_FALSE(funcs.empty());
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
...@@ -65,8 +66,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) { ...@@ -65,8 +66,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks(); std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
ASSERT_EQ(all_block_realizes.size(), 1UL); ASSERT_EQ(all_block_realizes.size(), 1UL);
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); all_block_realizes[0].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
AnalyzeScheduleBlockReadWriteBuffer(sche_block); AnalyzeScheduleBlockReadWriteBuffer(sche_block);
/* /*
...@@ -112,8 +115,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) { ...@@ -112,8 +115,9 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
ir::Tensor C = lang::Compute( ir::Tensor C = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = poly::CreateStages({C}); poly::StageMap stages = poly::CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("AddDiffShape", stages, {C}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"AddDiffShape", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: "; VLOG(6) << "Expr before MultiLevelTiling: ";
...@@ -126,8 +130,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) { ...@@ -126,8 +130,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks(); std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
ASSERT_EQ(all_block_realizes.size(), 1UL); ASSERT_EQ(all_block_realizes.size(), 1UL);
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); all_block_realizes[0].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
AnalyzeScheduleBlockReadWriteBuffer(sche_block); AnalyzeScheduleBlockReadWriteBuffer(sche_block);
VLOG(6) << "ScheduleBlockRealize: "; VLOG(6) << "ScheduleBlockRealize: ";
...@@ -163,8 +169,9 @@ TEST(AnalyzeIr, ContainsNodeType) { ...@@ -163,8 +169,9 @@ TEST(AnalyzeIr, ContainsNodeType) {
ir::Tensor B = lang::Compute( ir::Tensor B = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B}); poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ASSERT_FALSE(funcs.empty()); ASSERT_FALSE(funcs.empty());
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
...@@ -172,9 +179,12 @@ TEST(AnalyzeIr, ContainsNodeType) { ...@@ -172,9 +179,12 @@ TEST(AnalyzeIr, ContainsNodeType) {
VLOG(6) << "Analyzing for Expr:"; VLOG(6) << "Analyzing for Expr:";
VLOG(6) << ast_expr; VLOG(6) << ast_expr;
ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store})); ASSERT_TRUE(
ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse})); ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store}));
ASSERT_FALSE(ContainsNodeType(ast_expr, {ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum})); ASSERT_TRUE(ContainsNodeType(ast_expr,
{ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse}));
ASSERT_FALSE(ContainsNodeType(ast_expr,
{ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum}));
} }
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -38,13 +38,17 @@ ...@@ -38,13 +38,17 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
AutoTuner::AutoTuner(const common::Target& target, hlir::framework::Graph* graph) : target_(target), graph_(graph) {} AutoTuner::AutoTuner(const common::Target& target,
hlir::framework::Graph* graph)
: target_(target), graph_(graph) {}
void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler) { void AutoTuner::Initialize(const Config& config,
hlir::framework::GraphCompiler* graph_compiler) {
// create builder, runner, and schedule measurer // create builder, runner, and schedule measurer
builder_ = std::make_unique<SimpleBuilder>(graph_compiler); builder_ = std::make_unique<SimpleBuilder>(graph_compiler);
runner_ = std::make_unique<SimpleRunner>(config.runner_repeat_times); runner_ = std::make_unique<SimpleRunner>(config.runner_repeat_times);
schedule_measurer_ = std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get()); schedule_measurer_ =
std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get());
// initialize database // initialize database
database_ = std::move(Database::Make(config.database_config)); database_ = std::move(Database::Make(config.database_config));
...@@ -53,29 +57,43 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler* ...@@ -53,29 +57,43 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler*
TaskCreator task_creator; TaskCreator task_creator;
tasks_ = task_creator.CreateTuneTaskOpLevel(graph_); tasks_ = task_creator.CreateTuneTaskOpLevel(graph_);
const auto& dtype_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); const auto& dtype_dict =
const auto& shape_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph_->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target_); op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target_);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto i = 0; i < tasks_.size(); ++i) { for (auto i = 0; i < tasks_.size(); ++i) {
auto&& task = tasks_[i]; auto&& task = tasks_[i];
task.Initialize(shape_dict, dtype_dict, op_lowerer_.get()); task.Initialize(shape_dict, dtype_dict, op_lowerer_.get());
// Register the initial ModuleExpr corresponding to the task // Register the initial ModuleExpr corresponding to the task
task_registry->Regist(task.serialized_key, ir::ModuleExpr(task.GetLoweredFuncBodyExprs())); task_registry->Regist(task.serialized_key,
VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n" << task.serialized_key; ir::ModuleExpr(task.GetLoweredFuncBodyExprs()));
VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n"
<< task.serialized_key;
} }
// create task optimizers // create task optimizers
utils::LinearRandomEngine::StateType initial_seed = utils::LinearRandomEngine::GetDeviceRandomValue(); utils::LinearRandomEngine::StateType initial_seed =
utils::LinearRandomEngine::GetDeviceRandomValue();
task_optimizers_.resize(tasks_.size()); task_optimizers_.resize(tasks_.size());
std::transform(tasks_.begin(), tasks_.end(), task_optimizers_.begin(), [&](TuneTask& task) { std::transform(tasks_.begin(),
return std::make_unique<TaskOptimizer>( tasks_.end(),
&task, schedule_measurer_.get(), database_.get(), utils::ForkRandomState(&initial_seed)); task_optimizers_.begin(),
}); [&](TuneTask& task) {
return std::make_unique<TaskOptimizer>(
&task,
schedule_measurer_.get(),
database_.get(),
utils::ForkRandomState(&initial_seed));
});
// create task scheduler // create task scheduler
task_scheduler_ = TaskScheduler::Make(tasks_, config.task_schedule_config, config.task_schedule_strategy); task_scheduler_ = TaskScheduler::Make(
tasks_, config.task_schedule_config, config.task_schedule_strategy);
} }
void PrintResult(std::shared_ptr<hlir::framework::Graph::Group> group) { void PrintResult(std::shared_ptr<hlir::framework::Graph::Group> group) {
...@@ -127,7 +145,8 @@ void PrintResult(const TuningResult& result) { ...@@ -127,7 +145,8 @@ void PrintResult(const TuningResult& result) {
TuningResult AutoTuner::Tune(const TuningOptions& options) { TuningResult AutoTuner::Tune(const TuningOptions& options) {
CHECK_GT(options.num_tuning_rounds, 0) << "Invalid config"; CHECK_GT(options.num_tuning_rounds, 0) << "Invalid config";
VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds << ", tasks size=" << tasks_.size(); VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds
<< ", tasks size=" << tasks_.size();
TuningResult result; TuningResult result;
result.subgraphs.resize(tasks_.size()); result.subgraphs.resize(tasks_.size());
...@@ -136,7 +155,7 @@ TuningResult AutoTuner::Tune(const TuningOptions& options) { ...@@ -136,7 +155,7 @@ TuningResult AutoTuner::Tune(const TuningOptions& options) {
// as default result of graph tuning, and that should be updated // as default result of graph tuning, and that should be updated
// once we support graph tuning. // once we support graph tuning.
for (auto i = 0; i < tasks_.size(); ++i) { for (auto i = 0; i < tasks_.size(); ++i) {
auto&& task = tasks_.at(i); auto&& task = tasks_.at(i);
result.subgraphs[i] = task.subgraph; result.subgraphs[i] = task.subgraph;
} }
...@@ -146,7 +165,7 @@ TuningResult AutoTuner::Tune(const TuningOptions& options) { ...@@ -146,7 +165,7 @@ TuningResult AutoTuner::Tune(const TuningOptions& options) {
task_scheduler_->Reset(); task_scheduler_->Reset();
while ((run_id = task_scheduler_->NextTaskId()) != -1) { while ((run_id = task_scheduler_->NextTaskId()) != -1) {
VLOG(3) << "Start tuning Task-" << run_id; VLOG(3) << "Start tuning Task-" << run_id;
auto* opt = task_optimizers_.at(run_id).get(); auto* opt = task_optimizers_.at(run_id).get();
auto function_group = opt->Optimize(options); auto function_group = opt->Optimize(options);
VLOG(3) << "Task-" << run_id << " finished, print optimized functions:\n"; VLOG(3) << "Task-" << run_id << " finished, print optimized functions:\n";
PrintResult(function_group); PrintResult(function_group);
......
...@@ -49,7 +49,8 @@ class AutoTuner { ...@@ -49,7 +49,8 @@ class AutoTuner {
AutoTuner(const common::Target& target, hlir::framework::Graph* graph); AutoTuner(const common::Target& target, hlir::framework::Graph* graph);
// Initialize tuner with specific config and auxiliary objects. // Initialize tuner with specific config and auxiliary objects.
void Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler); void Initialize(const Config& config,
hlir::framework::GraphCompiler* graph_compiler);
// Perform the tuning process and return the final result // Perform the tuning process and return the final result
TuningResult Tune(const TuningOptions& options); TuningResult Tune(const TuningOptions& options);
......
...@@ -73,14 +73,16 @@ class TestAutoTuner : public ::testing::Test { ...@@ -73,14 +73,16 @@ class TestAutoTuner : public ::testing::Test {
// AutoTuner is combined with new IR Schedule // AutoTuner is combined with new IR Schedule
FLAGS_cinn_ir_schedule = true; FLAGS_cinn_ir_schedule = true;
std::unordered_set<std::string> fetch_ids; std::unordered_set<std::string> fetch_ids;
auto program = CreateAddReluProgram(); auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph); compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph); graph_compiler =
tuner = std::make_unique<AutoTuner>(target, graph.get()); std::make_unique<GraphCompiler>(target, compiled_scope, graph);
tuner = std::make_unique<AutoTuner>(target, graph.get());
} }
TuningResult InitializeAndTune(const AutoTuner::Config& config, const TuningOptions& options) { TuningResult InitializeAndTune(const AutoTuner::Config& config,
const TuningOptions& options) {
tuner->Initialize(config, graph_compiler.get()); tuner->Initialize(config, graph_compiler.get());
return tuner->Tune(options); return tuner->Tune(options);
} }
...@@ -108,7 +110,8 @@ class TestAutoTuner : public ::testing::Test { ...@@ -108,7 +110,8 @@ class TestAutoTuner : public ::testing::Test {
VLOG(6) << "Print lowered_funcs before building"; VLOG(6) << "Print lowered_funcs before building";
VLOG(6) << compile_options.lowered_funcs[0][0]; VLOG(6) << compile_options.lowered_funcs[0][0];
VLOG(6) << compile_options.lowered_funcs[1][0]; VLOG(6) << compile_options.lowered_funcs[1][0];
auto runtime_program = graph_compiler->Build(compile_options).runtime_program; auto runtime_program =
graph_compiler->Build(compile_options).runtime_program;
ASSERT_EQ(1, runtime_program->size()); ASSERT_EQ(1, runtime_program->size());
runtime_program->Execute(); runtime_program->Execute();
} }
...@@ -120,7 +123,7 @@ class TestAutoTuner : public ::testing::Test { ...@@ -120,7 +123,7 @@ class TestAutoTuner : public ::testing::Test {
TuningOptions tuning_options; TuningOptions tuning_options;
tuning_options.num_measure_trials = 0; tuning_options.num_measure_trials = 0;
auto result = InitializeAndTune(tuning_config, tuning_options); auto result = InitializeAndTune(tuning_config, tuning_options);
BasicCheckResult(result); BasicCheckResult(result);
ApplyTunedAndRun(result); ApplyTunedAndRun(result);
} }
...@@ -131,7 +134,7 @@ class TestAutoTuner : public ::testing::Test { ...@@ -131,7 +134,7 @@ class TestAutoTuner : public ::testing::Test {
tuning_config.task_schedule_strategy = "round_robin"; tuning_config.task_schedule_strategy = "round_robin";
TuningOptions tuning_options; TuningOptions tuning_options;
tuning_options.num_measure_trials = 4; tuning_options.num_measure_trials = 4;
tuning_options.num_samples_per_iteration = 2; tuning_options.num_samples_per_iteration = 2;
auto result = InitializeAndTune(tuning_config, tuning_options); auto result = InitializeAndTune(tuning_config, tuning_options);
......
...@@ -28,14 +28,15 @@ ...@@ -28,14 +28,15 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
float ExprCostModel::Predict(const ir::ModuleExpr& sample, const common::Target& target) const { float ExprCostModel::Predict(const ir::ModuleExpr& sample,
const common::Target& target) const {
if (trained_times_.load() == 0) { if (trained_times_.load() == 0) {
return SearchState::NOT_INIT_COST; return SearchState::NOT_INIT_COST;
} }
FeatureExtractor extractor; FeatureExtractor extractor;
Feature feature = extractor.Extract(sample, target); Feature feature = extractor.Extract(sample, target);
std::vector<float> feature_numbers = feature.ToFixedSizeVector(); std::vector<float> feature_numbers = feature.ToFixedSizeVector();
std::vector<float> pred = XgbCostModel::Predict({feature_numbers}); std::vector<float> pred = XgbCostModel::Predict({feature_numbers});
return pred[0]; return pred[0];
} }
...@@ -44,12 +45,13 @@ void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples, ...@@ -44,12 +45,13 @@ void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples,
const common::Target& target) { const common::Target& target) {
trained_times_.store(1); trained_times_.store(1);
size_t total_size = samples.size(); size_t total_size = samples.size();
CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels"; CHECK_EQ(total_size, labels.size())
<< "Samples must have same size as labels";
std::vector<std::vector<float>> train_feature_numbers(total_size); std::vector<std::vector<float>> train_feature_numbers(total_size);
FeatureExtractor extractor; FeatureExtractor extractor;
for (size_t i = 0; i < total_size; ++i) { for (size_t i = 0; i < total_size; ++i) {
CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr"; CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr";
Feature feature = extractor.Extract(*samples[i], target); Feature feature = extractor.Extract(*samples[i], target);
train_feature_numbers[i] = feature.ToFixedSizeVector(); train_feature_numbers[i] = feature.ToFixedSizeVector();
} }
...@@ -61,12 +63,13 @@ void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples, ...@@ -61,12 +63,13 @@ void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples,
const common::Target& target) { const common::Target& target) {
++trained_times_; ++trained_times_;
size_t total_size = samples.size(); size_t total_size = samples.size();
CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels"; CHECK_EQ(total_size, labels.size())
<< "Samples must have same size as labels";
std::vector<std::vector<float>> train_feature_numbers(total_size); std::vector<std::vector<float>> train_feature_numbers(total_size);
FeatureExtractor extractor; FeatureExtractor extractor;
for (size_t i = 0; i < total_size; ++i) { for (size_t i = 0; i < total_size; ++i) {
CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr"; CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr";
Feature feature = extractor.Extract(*samples[i], target); Feature feature = extractor.Extract(*samples[i], target);
train_feature_numbers[i] = feature.ToFixedSizeVector(); train_feature_numbers[i] = feature.ToFixedSizeVector();
} }
......
...@@ -29,7 +29,8 @@ namespace auto_schedule { ...@@ -29,7 +29,8 @@ namespace auto_schedule {
*/ */
class ExprCostModel : public XgbCostModel { class ExprCostModel : public XgbCostModel {
public: public:
virtual float Predict(const ir::ModuleExpr& sample, const common::Target& target) const; virtual float Predict(const ir::ModuleExpr& sample,
const common::Target& target) const;
void Train(const std::vector<const ir::ModuleExpr*>& samples, void Train(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels, const std::vector<float>& labels,
const common::Target& target); const common::Target& target);
......
...@@ -49,7 +49,8 @@ Feature::Feature(const common::Target& target) ...@@ -49,7 +49,8 @@ Feature::Feature(const common::Target& target)
parent_indices_(1, -1) {} parent_indices_(1, -1) {}
std::vector<float> Feature::ToFixedSizeVector() { std::vector<float> Feature::ToFixedSizeVector() {
std::vector<float> ret(LoopBlockFeature::kTotalSize + 1, 0); // LoopBlockFeature::kTotalSize plus 1 for target std::vector<float> ret(LoopBlockFeature::kTotalSize + 1,
0); // LoopBlockFeature::kTotalSize plus 1 for target
if (target_ == common::DefaultNVGPUTarget()) { if (target_ == common::DefaultNVGPUTarget()) {
ret[0] = 1; ret[0] = 1;
...@@ -58,13 +59,13 @@ std::vector<float> Feature::ToFixedSizeVector() { ...@@ -58,13 +59,13 @@ std::vector<float> Feature::ToFixedSizeVector() {
// loop[i] feature count should multiply iter_multi_num[i] // loop[i] feature count should multiply iter_multi_num[i]
std::vector<int> iter_multi_num; std::vector<int> iter_multi_num;
for (size_t i = 0; i < stack_encoded_feature_.size(); ++i) { for (size_t i = 0; i < stack_encoded_feature_.size(); ++i) {
int j = 1; int j = 1;
const LoopBlockFeature& loop_feature = stack_encoded_feature_[i]; const LoopBlockFeature& loop_feature = stack_encoded_feature_[i];
int loop_prod = 1; int loop_prod = 1;
int parent_prod = 1; int parent_prod = 1;
if (i != 0) { if (i != 0) {
parent_prod = iter_multi_num[parent_indices_[i]]; parent_prod = iter_multi_num[parent_indices_[i]];
loop_prod = parent_prod * loop_feature.loop_length; loop_prod = parent_prod * loop_feature.loop_length;
} }
iter_multi_num.push_back(loop_prod); iter_multi_num.push_back(loop_prod);
...@@ -165,11 +166,17 @@ void Feature::IntoLoopBlock() { ...@@ -165,11 +166,17 @@ void Feature::IntoLoopBlock() {
current_loop_block_index_ = stack_encoded_feature_.size() - 1; current_loop_block_index_ = stack_encoded_feature_.size() - 1;
} }
void Feature::ExitLoopBlock() { current_loop_block_index_ = parent_indices_[current_loop_block_index_]; } void Feature::ExitLoopBlock() {
current_loop_block_index_ = parent_indices_[current_loop_block_index_];
}
LoopBlockFeature& Feature::CurrentLoopBlock() { return stack_encoded_feature_[current_loop_block_index_]; } LoopBlockFeature& Feature::CurrentLoopBlock() {
return stack_encoded_feature_[current_loop_block_index_];
}
const LoopBlockFeature& Feature::CurrentLoopBlock() const { return stack_encoded_feature_[current_loop_block_index_]; } const LoopBlockFeature& Feature::CurrentLoopBlock() const {
return stack_encoded_feature_[current_loop_block_index_];
}
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -24,10 +24,18 @@ namespace cinn { ...@@ -24,10 +24,18 @@ namespace cinn {
namespace auto_schedule { namespace auto_schedule {
/* Loop feature enums */ /* Loop feature enums */
enum class ForOptimizeFeatureEnum : int { kNone, kGpuBind, kParallel, kUnroll, kVectorize }; enum class ForOptimizeFeatureEnum : int {
kNone,
kGpuBind,
kParallel,
kUnroll,
kVectorize
};
/* function to scale feature numbers */ /* function to scale feature numbers */
inline float slog(float x) { return x < 0 ? std::log2(-x + 1) : std::log2(x + 1); } inline float slog(float x) {
return x < 0 ? std::log2(-x + 1) : std::log2(x + 1);
}
class LoopBlockFeature { class LoopBlockFeature {
public: public:
...@@ -36,20 +44,20 @@ class LoopBlockFeature { ...@@ -36,20 +44,20 @@ class LoopBlockFeature {
// different bits, so we just distinguished int and float here // different bits, so we just distinguished int and float here
/* Arithmetic features */ /* Arithmetic features */
int float_add_or_sub = 0; int float_add_or_sub = 0;
int float_mul = 0; int float_mul = 0;
int float_div_or_mod = 0; int float_div_or_mod = 0;
int float_cmp = 0; int float_cmp = 0;
int float_math_func = 0; int float_math_func = 0;
int float_other_call = 0; // like simple assign, cast, etc. int float_other_call = 0; // like simple assign, cast, etc.
int int_add_or_sub = 0; int int_add_or_sub = 0;
int int_mul = 0; int int_mul = 0;
int int_div_or_mod = 0; int int_div_or_mod = 0;
int int_cmp = 0; int int_cmp = 0;
int int_math_func = 0; int int_math_func = 0;
int int_other_call = 0; // like simple assign, cast, etc. int int_other_call = 0; // like simple assign, cast, etc.
int bool_op = 0; int bool_op = 0;
int select_op = 0; int select_op = 0;
static constexpr int kArithSize = 6 * 2 + 2; static constexpr int kArithSize = 6 * 2 + 2;
...@@ -61,8 +69,8 @@ class LoopBlockFeature { ...@@ -61,8 +69,8 @@ class LoopBlockFeature {
* may be collect operand sizes (like alloc size, write size, or so) * may be collect operand sizes (like alloc size, write size, or so)
*/ */
int mem_alloc = 0; int mem_alloc = 0;
int mem_free = 0; int mem_free = 0;
int mem_read = 0; int mem_read = 0;
int mem_write = 0; int mem_write = 0;
static constexpr int kMemSize = 4; static constexpr int kMemSize = 4;
...@@ -71,16 +79,16 @@ class LoopBlockFeature { ...@@ -71,16 +79,16 @@ class LoopBlockFeature {
* Reduce and Broadcast features * Reduce and Broadcast features
*/ */
int float_reduce_sum_or_sub = 0; int float_reduce_sum_or_sub = 0;
int float_reduce_mul = 0; int float_reduce_mul = 0;
int float_reduce_div = 0; int float_reduce_div = 0;
int float_reduce_max_or_min = 0; int float_reduce_max_or_min = 0;
int float_broadcast = 0; int float_broadcast = 0;
int int_reduce_sum_or_sub = 0; int int_reduce_sum_or_sub = 0;
int int_reduce_mul = 0; int int_reduce_mul = 0;
int int_reduce_div = 0; int int_reduce_div = 0;
int int_reduce_max_or_min = 0; int int_reduce_max_or_min = 0;
int int_broadcast = 0; int int_broadcast = 0;
static constexpr int kReduceBroadcastSize = 10; static constexpr int kReduceBroadcastSize = 10;
...@@ -95,18 +103,20 @@ class LoopBlockFeature { ...@@ -95,18 +103,20 @@ class LoopBlockFeature {
/* Thread features if loop is optimized by GPU or CPU parallelism. /* Thread features if loop is optimized by GPU or CPU parallelism.
* Useless in other cases. * Useless in other cases.
*/ */
int len_blockIdx_x = 0; int len_blockIdx_x = 0;
int len_blockIdx_y = 0; int len_blockIdx_y = 0;
int len_blockIdx_z = 0; int len_blockIdx_z = 0;
int len_threadIdx_x = 0; int len_threadIdx_x = 0;
int len_threadIdx_y = 0; int len_threadIdx_y = 0;
int len_threadIdx_z = 0; int len_threadIdx_z = 0;
int len_vthread = 0; // length of virtual thread int len_vthread = 0; // length of virtual thread
int vectorize_factor = 0; int vectorize_factor = 0;
static constexpr int kThreadFeatureSize = 8; static constexpr int kThreadFeatureSize = 8;
static constexpr int kTotalSize = kArithSize + kMemSize + kReduceBroadcastSize + kOptApplySize + kThreadFeatureSize; static constexpr int kTotalSize = kArithSize + kMemSize +
kReduceBroadcastSize + kOptApplySize +
kThreadFeatureSize;
/* Non-feature attributes, used to maintain during feature_extractor */ /* Non-feature attributes, used to maintain during feature_extractor */
...@@ -158,10 +168,11 @@ class Feature { ...@@ -158,10 +168,11 @@ class Feature {
// some_compute_3 // some_compute_3
// } // }
// //
// We go through the code and push loops into stack, then the features are encoded as // We go through the code and push loops into stack, then the features are
// [loop_block_feature_0, loop_block_feature_1, loop_block_feature_2, loop_block_feature_3] // encoded as [loop_block_feature_0, loop_block_feature_1,
// where loop_block_feature_i stores the features of some_compute_i (such // loop_block_feature_2, loop_block_feature_3] where loop_block_feature_i
// as number of arithmetic operations) // stores the features of some_compute_i (such as number of arithmetic
// operations)
// //
// loop_block_feature_0.num_sub_loops = 2 // loop_block_feature_0.num_sub_loops = 2
// loop_block_feature_1.num_sub_loops = 1 // loop_block_feature_1.num_sub_loops = 1
......
...@@ -47,7 +47,8 @@ FeatureExtractor::FeatureExtractor() {} ...@@ -47,7 +47,8 @@ FeatureExtractor::FeatureExtractor() {}
void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); } void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); }
Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr, const common::Target &target) { Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
const common::Target &target) {
feature_ = Feature(target); feature_ = Feature(target);
for (const ir::Expr &e : mod_expr.GetExprs()) { for (const ir::Expr &e : mod_expr.GetExprs()) {
Visit(&e); Visit(&e);
...@@ -85,19 +86,20 @@ VisitDoNothing(_BufferRange_); ...@@ -85,19 +86,20 @@ VisitDoNothing(_BufferRange_);
NotVisitExprFields(_Tensor_) NotVisitExprFields(_Tensor_)
#define VisitForDtypePattern(NodeType, member) \ #define VisitForDtypePattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \ void FeatureExtractor::Visit(const NodeType *x) { \
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \ if (x->type() == common::F32() || x->type() == common::F16() || \
feature_.CurrentLoopBlock().float_##member += x->type().lanes(); \ x->type() == common::F64()) { \
} else { \ feature_.CurrentLoopBlock().float_##member += x->type().lanes(); \
feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \ } else { \
} \ feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \
std::vector<const Expr *> sub_exprs = x->expr_fields(); \ } \
for (const Expr *e : sub_exprs) { \ std::vector<const Expr *> sub_exprs = x->expr_fields(); \
if (e->defined()) { \ for (const Expr *e : sub_exprs) { \
Visit(e); \ if (e->defined()) { \
} \ Visit(e); \
} \ } \
} \
} }
VisitForDtypePattern(Add, add_or_sub); VisitForDtypePattern(Add, add_or_sub);
...@@ -118,19 +120,21 @@ VisitForDtypePattern(PrimitiveNode, math_func); ...@@ -118,19 +120,21 @@ VisitForDtypePattern(PrimitiveNode, math_func);
VisitForDtypePattern(Cast, other_call); VisitForDtypePattern(Cast, other_call);
VisitForDtypePattern(Let, other_call); VisitForDtypePattern(Let, other_call);
#define VisitForMultiOperandsDtypePattern(NodeType, member) \ #define VisitForMultiOperandsDtypePattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \ void FeatureExtractor::Visit(const NodeType *x) { \
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \ if (x->type() == common::F32() || x->type() == common::F16() || \
feature_.CurrentLoopBlock().float_##member += (x->operands().size() - 1); \ x->type() == common::F64()) { \
} else { \ feature_.CurrentLoopBlock().float_##member += \
feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \ (x->operands().size() - 1); \
} \ } else { \
std::vector<const Expr *> sub_exprs = x->expr_fields(); \ feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \
for (const Expr *e : sub_exprs) { \ } \
if (e->defined()) { \ std::vector<const Expr *> sub_exprs = x->expr_fields(); \
Visit(e); \ for (const Expr *e : sub_exprs) { \
} \ if (e->defined()) { \
} \ Visit(e); \
} \
} \
} }
VisitForMultiOperandsDtypePattern(Sum, add_or_sub); VisitForMultiOperandsDtypePattern(Sum, add_or_sub);
...@@ -166,23 +170,24 @@ void FeatureExtractor::Visit(const For *x) { ...@@ -166,23 +170,24 @@ void FeatureExtractor::Visit(const For *x) {
LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock(); LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock();
if (x->min.is_constant() && x->extent.is_constant()) { if (x->min.is_constant() && x->extent.is_constant()) {
loop_feature.loop_length = (x->extent.get_constant() - x->min.get_constant()); loop_feature.loop_length =
(x->extent.get_constant() - x->min.get_constant());
} else { } else {
loop_feature.loop_length = -1; // -1 represents unknown loop_feature.loop_length = -1; // -1 represents unknown
} }
if (x->is_parallel()) { if (x->is_parallel()) {
loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kParallel; loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kParallel;
loop_feature.len_vthread = loop_feature.loop_length; loop_feature.len_vthread = loop_feature.loop_length;
} else if (x->is_unrolled()) { } else if (x->is_unrolled()) {
loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kUnroll; loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kUnroll;
} else if (x->is_vectorized()) { } else if (x->is_vectorized()) {
loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kVectorize; loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kVectorize;
loop_feature.vectorize_factor = x->vectorize_info().factor; loop_feature.vectorize_factor = x->vectorize_info().factor;
} else if (x->is_binded()) { } else if (x->is_binded()) {
loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kGpuBind; loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kGpuBind;
const BindInfo &bind_info = x->bind_info(); const BindInfo &bind_info = x->bind_info();
int offset = bind_info.offset; int offset = bind_info.offset;
if (bind_info.for_type == ForType::GPUBlock) { if (bind_info.for_type == ForType::GPUBlock) {
if (offset == 0) { if (offset == 0) {
loop_feature.len_blockIdx_x = loop_feature.loop_length; loop_feature.len_blockIdx_x = loop_feature.loop_length;
...@@ -223,13 +228,16 @@ void FeatureExtractor::Visit(const PolyFor *x) { ...@@ -223,13 +228,16 @@ void FeatureExtractor::Visit(const PolyFor *x) {
/* Visit for Reduce and Broadcast */ /* Visit for Reduce and Broadcast */
void FeatureExtractor::Visit(const Reduce *x) { void FeatureExtractor::Visit(const Reduce *x) {
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { if (x->type() == common::F32() || x->type() == common::F16() ||
x->type() == common::F64()) {
switch (x->reduce_type) { switch (x->reduce_type) {
case Reduce::ReduceType::kSum: case Reduce::ReduceType::kSum:
feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
x->type().lanes();
break; break;
case Reduce::ReduceType::kSub: case Reduce::ReduceType::kSub:
feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
x->type().lanes();
break; break;
case Reduce::ReduceType::kDiv: case Reduce::ReduceType::kDiv:
feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes();
...@@ -238,10 +246,12 @@ void FeatureExtractor::Visit(const Reduce *x) { ...@@ -238,10 +246,12 @@ void FeatureExtractor::Visit(const Reduce *x) {
feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes();
break; break;
case Reduce::ReduceType::kMax: case Reduce::ReduceType::kMax:
feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_max_or_min +=
x->type().lanes();
break; break;
case Reduce::ReduceType::kMin: case Reduce::ReduceType::kMin:
feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_max_or_min +=
x->type().lanes();
break; break;
} }
} else { } else {
......
...@@ -48,9 +48,10 @@ TEST(FeatureExtractor, SimpleAssign) { ...@@ -48,9 +48,10 @@ TEST(FeatureExtractor, SimpleAssign) {
ir::Tensor B = lang::Compute( ir::Tensor B = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B}); poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
ir::Expr ast_expr = funcs[0]->body; "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr to test: " << ast_expr; VLOG(6) << "Expr to test: " << ast_expr;
std::vector<Expr> vec_ast{ast_expr}; std::vector<Expr> vec_ast{ast_expr};
...@@ -62,7 +63,8 @@ TEST(FeatureExtractor, SimpleAssign) { ...@@ -62,7 +63,8 @@ TEST(FeatureExtractor, SimpleAssign) {
std::vector<float> to_check = feature.ToFixedSizeVector(); std::vector<float> to_check = feature.ToFixedSizeVector();
ASSERT_EQ(to_check.size(), static_cast<size_t>(LoopBlockFeature::kTotalSize + 1)); ASSERT_EQ(to_check.size(),
static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
VLOG(6) << "Feature data before slog:"; VLOG(6) << "Feature data before slog:";
for (size_t i = 0; i < to_check.size(); ++i) { for (size_t i = 0; i < to_check.size(); ++i) {
VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1); VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
...@@ -77,9 +79,11 @@ TEST(FeatureExtractor, SimpleAssign) { ...@@ -77,9 +79,11 @@ TEST(FeatureExtractor, SimpleAssign) {
ASSERT_EQ(to_check[0], 0); ASSERT_EQ(to_check[0], 0);
#endif #endif
// mem_read // mem_read
ASSERT_EQ(to_check[17], slog(M.get_constant() * N.get_constant())); // mem_read ASSERT_EQ(to_check[17],
slog(M.get_constant() * N.get_constant())); // mem_read
// mem_write // mem_write
ASSERT_EQ(to_check[18], slog(M.get_constant() * N.get_constant())); // mem_write ASSERT_EQ(to_check[18],
slog(M.get_constant() * N.get_constant())); // mem_write
// non-opt loops, including root block // non-opt loops, including root block
ASSERT_EQ(to_check[29], slog(3)); ASSERT_EQ(to_check[29], slog(3));
} }
...@@ -101,16 +105,19 @@ TEST(FeatureExtractor, MatrixMultiply) { ...@@ -101,16 +105,19 @@ TEST(FeatureExtractor, MatrixMultiply) {
ir::Var k(K.as_int32(), "reduce_axis_k"); ir::Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = lang::Compute( ir::Tensor C = lang::Compute(
{M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); {M, N},
[&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = poly::CreateStages({C}); poly::StageMap stages = poly::CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
std::vector<Expr> vec_ast{funcs[0]->body}; std::vector<Expr> vec_ast{funcs[0]->body};
ir::ModuleExpr mod_expr(vec_ast); ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr); ir::IRSchedule ir_sch(mod_expr);
std::vector<ir::Expr> blocks = ir_sch.GetAllBlocks(); std::vector<ir::Expr> blocks = ir_sch.GetAllBlocks();
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]); std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
ir_sch.Bind(loops.back(), "threadIdx.x"); ir_sch.Bind(loops.back(), "threadIdx.x");
ir::Expr ast_expr = mod_expr.GetExprs()[0]; ir::Expr ast_expr = mod_expr.GetExprs()[0];
...@@ -121,7 +128,8 @@ TEST(FeatureExtractor, MatrixMultiply) { ...@@ -121,7 +128,8 @@ TEST(FeatureExtractor, MatrixMultiply) {
std::vector<float> to_check = feature.ToFixedSizeVector(); std::vector<float> to_check = feature.ToFixedSizeVector();
ASSERT_EQ(to_check.size(), static_cast<size_t>(LoopBlockFeature::kTotalSize + 1)); ASSERT_EQ(to_check.size(),
static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
std::unordered_set<size_t> non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37}; std::unordered_set<size_t> non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37};
for (size_t i = 0; i < to_check.size(); ++i) { for (size_t i = 0; i < to_check.size(); ++i) {
VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1); VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
...@@ -135,7 +143,7 @@ TEST(FeatureExtractor, MatrixMultiply) { ...@@ -135,7 +143,7 @@ TEST(FeatureExtractor, MatrixMultiply) {
#else #else
ASSERT_EQ(to_check[0], 0); ASSERT_EQ(to_check[0], 0);
#endif #endif
float out_loop = M.get_constant() * N.get_constant(); float out_loop = M.get_constant() * N.get_constant();
float total_loop = out_loop * K.get_constant(); float total_loop = out_loop * K.get_constant();
// float_mul // float_mul
ASSERT_EQ(to_check[1], slog(total_loop)); ASSERT_EQ(to_check[1], slog(total_loop));
......
...@@ -57,7 +57,8 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) { ...@@ -57,7 +57,8 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
Dtype* py_data = static_cast<Dtype*>(ret.mutable_data()); Dtype* py_data = static_cast<Dtype*>(ret.mutable_data());
for (size_t i = 0; i < vec.size(); ++i) { for (size_t i = 0; i < vec.size(); ++i) {
assert(vec[i].size() == shape[1] && "Sub vectors must have same size in VectorToNumpy"); assert(vec[i].size() == shape[1] &&
"Sub vectors must have same size in VectorToNumpy");
memcpy(py_data + (shape[1] * i), vec[i].data(), shape[1] * sizeof(Dtype)); memcpy(py_data + (shape[1] * i), vec[i].data(), shape[1] * sizeof(Dtype));
} }
return ret; return ret;
...@@ -71,19 +72,23 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) { ...@@ -71,19 +72,23 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
void AddDistPkgToPythonSysPath() { void AddDistPkgToPythonSysPath() {
pybind11::module sys_py_mod = pybind11::module::import("sys"); pybind11::module sys_py_mod = pybind11::module::import("sys");
// short version such as "3.7", "3.8", ... // short version such as "3.7", "3.8", ...
std::string py_short_version = sys_py_mod.attr("version").cast<std::string>().substr(0, 3); std::string py_short_version =
sys_py_mod.attr("version").cast<std::string>().substr(0, 3);
std::string site_pkg_str = "/usr/local/lib/python" + py_short_version + "/dist-packages"; std::string site_pkg_str =
"/usr/local/lib/python" + py_short_version + "/dist-packages";
sys_py_mod.attr("path").attr("append")(site_pkg_str); sys_py_mod.attr("path").attr("append")(site_pkg_str);
// TODO(zhhsplendid): warning to users if setuptools hasn't been installed // TODO(zhhsplendid): warning to users if setuptools hasn't been installed
DIR* site_pkg_dir = opendir(site_pkg_str.c_str()); DIR* site_pkg_dir = opendir(site_pkg_str.c_str());
if (site_pkg_dir != nullptr) { if (site_pkg_dir != nullptr) {
std::regex setuptool_regex("setuptools-.*-py" + py_short_version + "\\.egg"); std::regex setuptool_regex("setuptools-.*-py" + py_short_version +
"\\.egg");
struct dirent* entry = nullptr; struct dirent* entry = nullptr;
while ((entry = readdir(site_pkg_dir)) != nullptr) { while ((entry = readdir(site_pkg_dir)) != nullptr) {
if (std::regex_match(entry->d_name, setuptool_regex)) { if (std::regex_match(entry->d_name, setuptool_regex)) {
sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" + entry->d_name); sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" +
entry->d_name);
} }
} }
closedir(site_pkg_dir); closedir(site_pkg_dir);
...@@ -96,40 +101,49 @@ XgbCostModel::XgbCostModel() { ...@@ -96,40 +101,49 @@ XgbCostModel::XgbCostModel() {
if (previous == 0) { if (previous == 0) {
AddDistPkgToPythonSysPath(); AddDistPkgToPythonSysPath();
} }
xgb_module_ = pybind11::module::import("xgboost"); xgb_module_ = pybind11::module::import("xgboost");
xgb_booster_ = xgb_module_.attr("Booster")(); xgb_booster_ = xgb_module_.attr("Booster")();
} }
void XgbCostModel::Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) { void XgbCostModel::Train(const std::vector<std::vector<float>>& samples,
update_samples_ = samples; const std::vector<float>& labels) {
update_labels_ = labels; update_samples_ = samples;
update_labels_ = labels;
pybind11::array np_samples = VectorToNumpy<float>(samples); pybind11::array np_samples = VectorToNumpy<float>(samples);
pybind11::array np_labels = VectorToNumpy<float>(labels); pybind11::array np_labels = VectorToNumpy<float>(labels);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels); pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_)); xgb_booster_ = xgb_module_.attr("train")(
pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
} }
std::vector<float> XgbCostModel::Predict(const std::vector<std::vector<float>>& samples) const { std::vector<float> XgbCostModel::Predict(
const std::vector<std::vector<float>>& samples) const {
pybind11::array np_samples = VectorToNumpy<float>(samples); pybind11::array np_samples = VectorToNumpy<float>(samples);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples); pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples);
pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix); pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix);
return py_result.cast<std::vector<float>>(); return py_result.cast<std::vector<float>>();
} }
void XgbCostModel::Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) { void XgbCostModel::Update(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) {
update_samples_.insert(update_samples_.end(), samples.begin(), samples.end()); update_samples_.insert(update_samples_.end(), samples.begin(), samples.end());
update_labels_.insert(update_labels_.end(), labels.begin(), labels.end()); update_labels_.insert(update_labels_.end(), labels.begin(), labels.end());
pybind11::array np_samples = VectorToNumpy<float>(update_samples_); pybind11::array np_samples = VectorToNumpy<float>(update_samples_);
pybind11::array np_labels = VectorToNumpy<float>(update_labels_); pybind11::array np_labels = VectorToNumpy<float>(update_labels_);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels); pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_)); xgb_booster_ = xgb_module_.attr("train")(
pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
} }
void XgbCostModel::Save(const std::string& path) { xgb_booster_.attr("save_model")(pybind11::str(path)); } void XgbCostModel::Save(const std::string& path) {
xgb_booster_.attr("save_model")(pybind11::str(path));
}
void XgbCostModel::Load(const std::string& path) { xgb_booster_.attr("load_model")(pybind11::str(path)); } void XgbCostModel::Load(const std::string& path) {
xgb_booster_.attr("load_model")(pybind11::str(path));
}
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -47,11 +47,14 @@ class XgbCostModel : public CostModel { ...@@ -47,11 +47,14 @@ class XgbCostModel : public CostModel {
XgbCostModel(); XgbCostModel();
~XgbCostModel() = default; ~XgbCostModel() = default;
void Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override; void Train(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) override;
std::vector<float> Predict(const std::vector<std::vector<float>>& samples) const override; std::vector<float> Predict(
const std::vector<std::vector<float>>& samples) const override;
void Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override; void Update(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) override;
void Save(const std::string& path) override; void Save(const std::string& path) override;
......
...@@ -31,10 +31,11 @@ TEST(CostModel, Basic) { ...@@ -31,10 +31,11 @@ TEST(CostModel, Basic) {
srand(time(NULL)); srand(time(NULL));
int batch_size = 16; int batch_size = 16;
int feature_size = 8; int feature_size = 8;
std::vector<float> labels(batch_size, 1.0); std::vector<float> labels(batch_size, 1.0);
std::vector<std::vector<float>> samples(batch_size, std::vector<float>(feature_size)); std::vector<std::vector<float>> samples(batch_size,
std::vector<float>(feature_size));
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < feature_size; ++j) { for (int j = 0; j < feature_size; ++j) {
samples[i][j] = rand() % 10; samples[i][j] = rand() % 10;
......
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
bool TuningRecord::Compare::operator()(const TuningRecord& lhs, const TuningRecord& rhs) const { bool TuningRecord::Compare::operator()(const TuningRecord& lhs,
const TuningRecord& rhs) const {
return lhs.execution_cost < rhs.execution_cost; return lhs.execution_cost < rhs.execution_cost;
} }
...@@ -39,15 +40,18 @@ proto::TuningRecord TuningRecord::ToProto() const { ...@@ -39,15 +40,18 @@ proto::TuningRecord TuningRecord::ToProto() const {
return record_proto; return record_proto;
} }
Database::Database(int capacity_per_task) : capacity_per_task_(capacity_per_task) { Database::Database(int capacity_per_task)
CHECK_GT(capacity_per_task_, 0) << "capacity_per_task_ should be greater than 0"; : capacity_per_task_(capacity_per_task) {
CHECK_GT(capacity_per_task_, 0)
<< "capacity_per_task_ should be greater than 0";
} }
std::unique_ptr<Database> Database::Make(const DatabaseConfig& config) { std::unique_ptr<Database> Database::Make(const DatabaseConfig& config) {
if (config.type == DatabaseType::kMemory) { if (config.type == DatabaseType::kMemory) {
return std::make_unique<Database>(config.capacity_per_task); return std::make_unique<Database>(config.capacity_per_task);
} else if (config.type == DatabaseType::kJSONFile) { } else if (config.type == DatabaseType::kJSONFile) {
return std::make_unique<JSONFileDatabase>(config.capacity_per_task, config.record_file_path, true); return std::make_unique<JSONFileDatabase>(
config.capacity_per_task, config.record_file_path, true);
} }
LOG(FATAL) << "Unimplemented database type."; LOG(FATAL) << "Unimplemented database type.";
...@@ -81,13 +85,16 @@ std::vector<TuningRecord> Database::LookUp(const std::string& task_key) { ...@@ -81,13 +85,16 @@ std::vector<TuningRecord> Database::LookUp(const std::string& task_key) {
return results; return results;
} }
std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k) { std::vector<TuningRecord> Database::GetTopK(const std::string& task_key,
int k) {
auto fit = key2record_.find(task_key); auto fit = key2record_.find(task_key);
if (fit == key2record_.end() || k <= 0) { if (fit == key2record_.end() || k <= 0) {
return {}; return {};
} }
if (k > capacity_per_task_) { if (k > capacity_per_task_) {
LOG(WARNING) << "Top k=" << k << " is greater than the capacity, will adjust k=" << capacity_per_task_; LOG(WARNING) << "Top k=" << k
<< " is greater than the capacity, will adjust k="
<< capacity_per_task_;
k = capacity_per_task_; k = capacity_per_task_;
} }
...@@ -103,10 +110,12 @@ std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k) ...@@ -103,10 +110,12 @@ std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k)
} }
size_t Database::Size() { size_t Database::Size() {
auto res = auto res = std::accumulate(key2record_.begin(),
std::accumulate(key2record_.begin(), key2record_.end(), size_t(0), [](size_t res, const auto& kv) -> size_t { key2record_.end(),
return std::move(res) + kv.second.size(); size_t(0),
}); [](size_t res, const auto& kv) -> size_t {
return std::move(res) + kv.second.size();
});
return res; return res;
} }
......
...@@ -39,7 +39,9 @@ struct TuningRecord { ...@@ -39,7 +39,9 @@ struct TuningRecord {
predicted_cost(record.predicted_cost()), predicted_cost(record.predicted_cost()),
trace(record.trace()), trace(record.trace()),
execution_cost(record.execution_cost()) {} execution_cost(record.execution_cost()) {}
TuningRecord(const std::string& task_key, const SearchState& state, double execution_cost) TuningRecord(const std::string& task_key,
const SearchState& state,
double execution_cost)
: task_key(task_key), : task_key(task_key),
predicted_cost(state->predicted_cost), predicted_cost(state->predicted_cost),
trace(state->ir_schedule.GetTraceDesc().ToProto()), trace(state->ir_schedule.GetTraceDesc().ToProto()),
...@@ -58,15 +60,15 @@ struct TuningRecord { ...@@ -58,15 +60,15 @@ struct TuningRecord {
enum class DatabaseType : int { kMemory, kJSONFile }; enum class DatabaseType : int { kMemory, kJSONFile };
struct DatabaseConfig { struct DatabaseConfig {
DatabaseType type = DatabaseType::kMemory; DatabaseType type = DatabaseType::kMemory;
int capacity_per_task = 2; int capacity_per_task = 2;
std::string record_file_path = "/tmp/tuning_record.json"; std::string record_file_path = "/tmp/tuning_record.json";
}; };
// A database supports insert or lookup historial tuning result with specified traits. // A database supports insert or lookup historial tuning result with specified
// It can be implemented with a concrete storage to save/load underlying data, // traits. It can be implemented with a concrete storage to save/load underlying
// such as memory, file, database server and so on, this base class can be regarded as // data, such as memory, file, database server and so on, this base class can be
// one using memory as its underlying storage medium. // regarded as one using memory as its underlying storage medium.
class Database { class Database {
public: public:
explicit Database(int capacity_per_task); explicit Database(int capacity_per_task);
...@@ -93,7 +95,9 @@ class Database { ...@@ -93,7 +95,9 @@ class Database {
void Insert(const TuningRecord& record); void Insert(const TuningRecord& record);
// map task_key to its records // map task_key to its records
std::unordered_map<std::string, std::multiset<TuningRecord, TuningRecord::Compare>> key2record_; std::unordered_map<std::string,
std::multiset<TuningRecord, TuningRecord::Compare>>
key2record_;
// the max number of candidates stored // the max number of candidates stored
const int capacity_per_task_; const int capacity_per_task_;
}; };
......
...@@ -57,8 +57,10 @@ TEST_F(TestDatabase, GetTopK) { ...@@ -57,8 +57,10 @@ TEST_F(TestDatabase, GetTopK) {
ASSERT_TRUE(test_db.GetTopK("k5", 2).empty()); ASSERT_TRUE(test_db.GetTopK("k5", 2).empty());
ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1); ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1);
test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0)); test_db.AddRecord(
test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0)); TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0));
test_db.AddRecord(
TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0));
auto records = test_db.GetTopK("k4", 3); auto records = test_db.GetTopK("k4", 3);
ASSERT_EQ(records.size(), 2); ASSERT_EQ(records.size(), 2);
......
...@@ -35,7 +35,8 @@ void AppendLineToFile(const std::string& file_path, const std::string& line) { ...@@ -35,7 +35,8 @@ void AppendLineToFile(const std::string& file_path, const std::string& line) {
} }
// read lines from a json file // read lines from a json file
std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool allow_new_file) { std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
bool allow_new_file) {
std::ifstream is(file_path); std::ifstream is(file_path);
if (is.good()) { if (is.good()) {
std::vector<std::string> json_strs; std::vector<std::string> json_strs;
...@@ -51,20 +52,26 @@ std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool al ...@@ -51,20 +52,26 @@ std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool al
return {}; return {};
} }
JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file) JSONFileDatabase::JSONFileDatabase(int capacity_per_task,
const std::string& record_file_path,
bool allow_new_file)
: Database(capacity_per_task), record_file_path_(record_file_path) { : Database(capacity_per_task), record_file_path_(record_file_path) {
VLOG(3) << "Auto schedule will save/load tuning records on file:" << record_file_path; VLOG(3) << "Auto schedule will save/load tuning records on file:"
<< record_file_path;
auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file); auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file);
std::vector<cinn::auto_schedule::proto::TuningRecord> all_records_proto(json_lines.size()); std::vector<cinn::auto_schedule::proto::TuningRecord> all_records_proto(
json_lines.size());
// convert JSON string to proto object // convert JSON string to proto object
auto worker_fn = [this, &json_lines, &all_records_proto](int index) { auto worker_fn = [this, &json_lines, &all_records_proto](int index) {
cinn::auto_schedule::proto::TuningRecord record_proto; cinn::auto_schedule::proto::TuningRecord record_proto;
auto status = google::protobuf::util::JsonStringToMessage(json_lines[index], &record_proto); auto status = google::protobuf::util::JsonStringToMessage(json_lines[index],
&record_proto);
CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index]; CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index];
all_records_proto[index].Swap(&record_proto); all_records_proto[index].Swap(&record_proto);
}; };
utils::parallel_run(worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1); utils::parallel_run(
worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
...@@ -81,8 +88,10 @@ JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& rec ...@@ -81,8 +88,10 @@ JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& rec
std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) { std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) {
proto::TuningRecord record_proto = record.ToProto(); proto::TuningRecord record_proto = record.ToProto();
std::string json_string; std::string json_string;
auto status = google::protobuf::util::MessageToJsonString(record_proto, &json_string); auto status =
CHECK(status.ok()) << "Failed to serialize record to JSON, task key = " << record.task_key; google::protobuf::util::MessageToJsonString(record_proto, &json_string);
CHECK(status.ok()) << "Failed to serialize record to JSON, task key = "
<< record.task_key;
VLOG(4) << "json_string = \n" << json_string; VLOG(4) << "json_string = \n" << json_string;
return json_string; return json_string;
......
...@@ -19,16 +19,20 @@ ...@@ -19,16 +19,20 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
// JSONFileDatabase is a database implemented by JSON file to save/load underlying data. // JSONFileDatabase is a database implemented by JSON file to save/load
// underlying data.
class JSONFileDatabase : public Database { class JSONFileDatabase : public Database {
public: public:
/*! /*!
* \brief Build a JSONFileDatabase object from a json file. * \brief Build a JSONFileDatabase object from a json file.
* \param capacity_per_task The max number of candidates stored. * \param capacity_per_task The max number of candidates stored.
* \param record_file_path The path of the json file. * \param record_file_path The path of the json file.
* \param allow_new_file Whether to create new file when the given path is not found. * \param allow_new_file Whether to create new file when the given path is not
* found.
*/ */
JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file); JSONFileDatabase(int capacity_per_task,
const std::string& record_file_path,
bool allow_new_file);
~JSONFileDatabase() = default; ~JSONFileDatabase() = default;
// convert a TuningRecord object to string in JSON format // convert a TuningRecord object to string in JSON format
...@@ -46,7 +50,8 @@ class JSONFileDatabase : public Database { ...@@ -46,7 +50,8 @@ class JSONFileDatabase : public Database {
void AppendLineToFile(const std::string& file_path, const std::string& line); void AppendLineToFile(const std::string& file_path, const std::string& line);
// read lines from a json file // read lines from a json file
std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool allow_new_file = true); std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
bool allow_new_file = true);
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -31,7 +31,8 @@ namespace cinn { ...@@ -31,7 +31,8 @@ namespace cinn {
namespace auto_schedule { namespace auto_schedule {
// Return lowerd ir AST for example functions used in this test // Return lowerd ir AST for example functions used in this test
std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const Target& target) { std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape,
const Target& target) {
CHECK(shape.size() == 2) << "shape should be 2"; CHECK(shape.size() == 2) << "shape should be 2";
std::vector<Expr> domain; std::vector<Expr> domain;
for (auto i = 0; i < shape.size(); ++i) { for (auto i = 0; i < shape.size(); ++i) {
...@@ -46,11 +47,13 @@ std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const T ...@@ -46,11 +47,13 @@ std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const T
C = Compute( C = Compute(
domain, [&B](Var i, Var j) { return B(i, j); }, "C"); domain, [&B](Var i, Var j) { return B(i, j); }, "C");
return cinn::lang::LowerVec("test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); return cinn::lang::LowerVec(
"test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
} }
// Create a new IRSchedule with copied ir::LoweredFunc AST // Create a new IRSchedule with copied ir::LoweredFunc AST
ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs, const std::string& task_key) { ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
const std::string& task_key) {
std::vector<Expr> exprs; std::vector<Expr> exprs;
for (auto&& func : lowered_funcs) { for (auto&& func : lowered_funcs) {
exprs.emplace_back(optim::IRCopy(func->body)); exprs.emplace_back(optim::IRCopy(func->body));
...@@ -63,7 +66,9 @@ ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs, ...@@ -63,7 +66,9 @@ ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
class TestJSONFileDatabase : public ::testing::Test { class TestJSONFileDatabase : public ::testing::Test {
public: public:
TestJSONFileDatabase() : record_file_path("/tmp/test_record.json"), test_db(2, record_file_path, true) {} TestJSONFileDatabase()
: record_file_path("/tmp/test_record.json"),
test_db(2, record_file_path, true) {}
void SetUp() override { lowered_funcs = LowerCompute({32, 32}, target); } void SetUp() override { lowered_funcs = LowerCompute({32, 32}, target); }
...@@ -91,55 +96,76 @@ class TestJSONFileDatabase : public ::testing::Test { ...@@ -91,55 +96,76 @@ class TestJSONFileDatabase : public ::testing::Test {
TEST_F(TestJSONFileDatabase, Serialize) { TEST_F(TestJSONFileDatabase, Serialize) {
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "test"); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "test");
auto fused = ir_sch.Fuse("B", {0, 1}); auto fused = ir_sch.Fuse("B", {0, 1});
VLOG(3) << "after Fuse, Expr: " << fused; VLOG(3) << "after Fuse, Expr: " << fused;
TuningRecord record1("test", SearchState(std::move(ir_sch), 2.0), 1.0); TuningRecord record1("test", SearchState(std::move(ir_sch), 2.0), 1.0);
std::string str = test_db.RecordToJSON(record1); std::string str = test_db.RecordToJSON(record1);
VLOG(3) << "RecordToJSON: " << str; VLOG(3) << "RecordToJSON: " << str;
// Because the serialization of protobuf does not guarantee the order, we give all possible results. // Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
std::string case1 = std::string case1 =
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," "{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":"
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_" "{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":"
"\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}"; "name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}";
std::string case2 = std::string case2 =
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," "{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":"
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_" "{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":"
"\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}"; "index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}";
EXPECT_EQ(true, str == case1 || str == case2); EXPECT_EQ(true, str == case1 || str == case2);
} }
TEST_F(TestJSONFileDatabase, SaveLoad) { TEST_F(TestJSONFileDatabase, SaveLoad) {
ir::IRSchedule ir_sch1 = MakeIRSchedule(lowered_funcs, "k1"); ir::IRSchedule ir_sch1 = MakeIRSchedule(lowered_funcs, "k1");
auto fused1 = ir_sch1.Fuse("B", {0, 1}); auto fused1 = ir_sch1.Fuse("B", {0, 1});
ir::IRSchedule ir_sch2 = MakeIRSchedule(lowered_funcs, "k2"); ir::IRSchedule ir_sch2 = MakeIRSchedule(lowered_funcs, "k2");
test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch1), 1.5), 1.0)); test_db.AddRecord(
test_db.AddRecord(TuningRecord("k2", SearchState(std::move(ir_sch2), 3.5), 3.0)); TuningRecord("k1", SearchState(std::move(ir_sch1), 1.5), 1.0));
test_db.AddRecord(
TuningRecord("k2", SearchState(std::move(ir_sch2), 3.5), 3.0));
std::vector<std::string> strs = ReadLinesFromFile(record_file_path); std::vector<std::string> strs = ReadLinesFromFile(record_file_path);
ASSERT_EQ(strs.size(), 2); ASSERT_EQ(strs.size(), 2);
// Because the serialization of protobuf does not guarantee the order, we give all possible results. // Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
std::string case1 = std::string case1 =
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":"
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_" "{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":"
"\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}"; "name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}";
std::string case2 = std::string case2 =
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":"
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_" "{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":"
"\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}"; "index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}";
EXPECT_EQ(true, strs[0] == case1 || strs[0] == case2); EXPECT_EQ(true, strs[0] == case1 || strs[0] == case2);
EXPECT_EQ(strs[1], "{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,\"trace\":{}}"); EXPECT_EQ(strs[1],
"{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,"
"\"trace\":{}}");
} }
TEST_F(TestJSONFileDatabase, Basic) { TEST_F(TestJSONFileDatabase, Basic) {
test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); "k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0)); "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0)); "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0)); test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0));
ASSERT_EQ(test_db.Size(), 6); ASSERT_EQ(test_db.Size(), 6);
auto records = test_db.LookUp("k3"); auto records = test_db.LookUp("k3");
...@@ -152,15 +178,24 @@ TEST_F(TestJSONFileDatabase, Basic) { ...@@ -152,15 +178,24 @@ TEST_F(TestJSONFileDatabase, Basic) {
} }
TEST_F(TestJSONFileDatabase, GetTopK) { TEST_F(TestJSONFileDatabase, GetTopK) {
test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); "k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0)); "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0)); "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0)); "k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0)); test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0));
auto records = test_db.GetTopK("k4", 3); auto records = test_db.GetTopK("k4", 3);
ASSERT_EQ(records.size(), 2); ASSERT_EQ(records.size(), 2);
...@@ -170,9 +205,11 @@ TEST_F(TestJSONFileDatabase, GetTopK) { ...@@ -170,9 +205,11 @@ TEST_F(TestJSONFileDatabase, GetTopK) {
TEST_F(TestJSONFileDatabase, Reload) { TEST_F(TestJSONFileDatabase, Reload) {
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1"); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1");
auto fused = ir_sch.Fuse("B", {0, 1}); auto fused = ir_sch.Fuse("B", {0, 1});
test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0)); test_db.AddRecord(
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0));
test_db.AddRecord(TuningRecord(
"k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
auto records = test_db.LookUp("k1"); auto records = test_db.LookUp("k1");
ASSERT_EQ(records.size(), 1); ASSERT_EQ(records.size(), 1);
...@@ -184,11 +221,13 @@ TEST_F(TestJSONFileDatabase, Reload) { ...@@ -184,11 +221,13 @@ TEST_F(TestJSONFileDatabase, Reload) {
EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost); EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost);
EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost); EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost);
// check the equality of trace info between original TuningRecord and the loaded TuningRecord // check the equality of trace info between original TuningRecord and the
// loaded TuningRecord
const auto& lhs_trace = records[0].trace; const auto& lhs_trace = records[0].trace;
const auto& rhs_trace = loaded_records[0].trace; const auto& rhs_trace = loaded_records[0].trace;
google::protobuf::util::MessageDifferencer dif; google::protobuf::util::MessageDifferencer dif;
static const google::protobuf::Descriptor* descriptor = cinn::ir::proto::ScheduleDesc_Step::descriptor(); static const google::protobuf::Descriptor* descriptor =
cinn::ir::proto::ScheduleDesc_Step::descriptor();
dif.TreatAsSet(descriptor->FindFieldByName("attrs")); dif.TreatAsSet(descriptor->FindFieldByName("attrs"));
EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace)); EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace));
...@@ -203,8 +242,8 @@ TEST_F(TestJSONFileDatabase, Reload) { ...@@ -203,8 +242,8 @@ TEST_F(TestJSONFileDatabase, Reload) {
ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size()); ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
for (auto i = 0; i < lhs_exprs.size(); ++i) { for (auto i = 0; i < lhs_exprs.size(); ++i) {
std::string lhs = utils::GetStreamCnt(lhs_exprs.at(i)); std::string lhs = utils::GetStreamCnt(lhs_exprs.at(i));
std::string rhs = utils::GetStreamCnt(rhs_exprs.at(i)); std::string rhs = utils::GetStreamCnt(rhs_exprs.at(i));
size_t remove_prefix_len = 28; size_t remove_prefix_len = 28;
ASSERT_EQ(lhs.erase(0, remove_prefix_len), rhs.erase(0, remove_prefix_len)); ASSERT_EQ(lhs.erase(0, remove_prefix_len), rhs.erase(0, remove_prefix_len));
} }
......
...@@ -53,7 +53,8 @@ struct MeasureResult { ...@@ -53,7 +53,8 @@ struct MeasureResult {
// The result of building with input schedule // The result of building with input schedule
struct BuildResult { struct BuildResult {
// The scope that owns detail compilation infos of parameters in the runtime program // The scope that owns detail compilation infos of parameters in the runtime
// program
const hlir::framework::Scope* compiled_scope; const hlir::framework::Scope* compiled_scope;
// The executable program // The executable program
std::unique_ptr<hlir::framework::Program> runtime_program; std::unique_ptr<hlir::framework::Program> runtime_program;
...@@ -68,11 +69,13 @@ class ScheduleBuilder { ...@@ -68,11 +69,13 @@ class ScheduleBuilder {
virtual BuildResult Build(const MeasureInput& input) = 0; virtual BuildResult Build(const MeasureInput& input) = 0;
}; };
// This interface defines how to run the built result. Like above ScheduleBuilder, // This interface defines how to run the built result. Like above
// a runner shoule be implemented with not bound to a specific task. // ScheduleBuilder, a runner shoule be implemented with not bound to a specific
// task.
class ScheduleRunner { class ScheduleRunner {
public: public:
virtual MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) = 0; virtual MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) = 0;
}; };
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -62,22 +62,27 @@ class TestMeasurer : public ::testing::Test { ...@@ -62,22 +62,27 @@ class TestMeasurer : public ::testing::Test {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
std::unordered_set<std::string> fetch_ids; std::unordered_set<std::string> fetch_ids;
auto program = CreateAddReluProgram(); auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph); auto scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, scope, graph); graph_compiler = std::make_unique<GraphCompiler>(target, scope, graph);
TaskCreator task_creator; TaskCreator task_creator;
tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); const auto& dtype_dict =
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target); const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
"infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
inputs.reserve(tasks.size()); inputs.reserve(tasks.size());
for (int i = 0; i < tasks.size(); ++i) { for (int i = 0; i < tasks.size(); ++i) {
auto* task = &tasks[i]; auto* task = &tasks[i];
task->Initialize(shape_dict, dtype_dict, op_lowerer.get()); task->Initialize(shape_dict, dtype_dict, op_lowerer.get());
MeasureInput input; MeasureInput input;
input.task = task; input.task = task;
input.lowered_funcs = task->lowered_funcs; input.lowered_funcs = task->lowered_funcs;
inputs.emplace_back(input); inputs.emplace_back(input);
} }
...@@ -95,30 +100,37 @@ class ThrowExceptionRunner : public ScheduleRunner { ...@@ -95,30 +100,37 @@ class ThrowExceptionRunner : public ScheduleRunner {
struct Exception : public std::exception { struct Exception : public std::exception {
const char* what() const throw() { return "RunError"; } const char* what() const throw() { return "RunError"; }
}; };
MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override { throw Exception(); } MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) override {
throw Exception();
}
}; };
TEST_F(TestMeasurer, Basic) { TEST_F(TestMeasurer, Basic) {
auto builder = std::make_unique<SimpleBuilder>(graph_compiler.get()); auto builder = std::make_unique<SimpleBuilder>(graph_compiler.get());
auto runner = std::make_unique<SimpleRunner>(1); auto runner = std::make_unique<SimpleRunner>(1);
auto measurer = std::make_unique<ScheduleMeasurer>(builder.get(), runner.get()); auto measurer =
std::make_unique<ScheduleMeasurer>(builder.get(), runner.get());
std::vector<MeasureResult> results = measurer->Measure(inputs); std::vector<MeasureResult> results = measurer->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size()); ASSERT_EQ(inputs.size(), results.size());
} }
TEST_F(TestMeasurer, CatchException) { TEST_F(TestMeasurer, CatchException) {
auto builder = std::make_unique<SimpleBuilder>(graph_compiler.get()); auto builder = std::make_unique<SimpleBuilder>(graph_compiler.get());
auto runner = std::make_unique<SimpleRunner>(1); auto runner = std::make_unique<SimpleRunner>(1);
auto throw_builder = std::make_unique<ThrowExceptionBuilder>(); auto throw_builder = std::make_unique<ThrowExceptionBuilder>();
auto throw_runner = std::make_unique<ThrowExceptionRunner>(); auto throw_runner = std::make_unique<ThrowExceptionRunner>();
auto measurer_with_build_error = std::make_unique<ScheduleMeasurer>(throw_builder.get(), runner.get(), 2); auto measurer_with_build_error =
std::vector<MeasureResult> results = measurer_with_build_error->Measure(inputs); std::make_unique<ScheduleMeasurer>(throw_builder.get(), runner.get(), 2);
std::vector<MeasureResult> results =
measurer_with_build_error->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size()); ASSERT_EQ(inputs.size(), results.size());
EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n"); EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n");
// TODO(CtfGo): test parallel build after we support thread-safe compilation // TODO(CtfGo): test parallel build after we support thread-safe compilation
auto measurer_with_run_error = std::make_unique<ScheduleMeasurer>(builder.get(), throw_runner.get(), 1); auto measurer_with_run_error =
results = measurer_with_run_error->Measure(inputs); std::make_unique<ScheduleMeasurer>(builder.get(), throw_runner.get(), 1);
results = measurer_with_run_error->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size()); ASSERT_EQ(inputs.size(), results.size());
EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n"); EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n");
} }
......
...@@ -21,10 +21,13 @@ ...@@ -21,10 +21,13 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads) ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder,
ScheduleRunner* runner,
int num_threads)
: builder_(builder), runner_(runner), num_threads_(num_threads) {} : builder_(builder), runner_(runner), num_threads_(num_threads) {}
std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureInput>& inputs) { std::vector<MeasureResult> ScheduleMeasurer::Measure(
const std::vector<MeasureInput>& inputs) {
if (inputs.empty()) { if (inputs.empty()) {
LOG(WARNING) << "inputs is empty"; LOG(WARNING) << "inputs is empty";
return {}; return {};
...@@ -33,41 +36,49 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn ...@@ -33,41 +36,49 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn
std::vector<MeasureResult> results(inputs.size()); std::vector<MeasureResult> results(inputs.size());
// define how to build a candidate with the specified index // define how to build a candidate with the specified index
auto build_fn = [builder = builder_, &inputs, &build_results, &results](int index) { auto build_fn =
VLOG(6) << "Build candidate index: " << index; [builder = builder_, &inputs, &build_results, &results](int index) {
auto m_start = std::chrono::steady_clock::now(); VLOG(6) << "Build candidate index: " << index;
try { auto m_start = std::chrono::steady_clock::now();
build_results[index] = builder->Build(inputs[index]); try {
} catch (std::exception& e) { build_results[index] = builder->Build(inputs[index]);
results[index].error_msg = utils::StringFormat("Build failed, error: %s\n", e.what()); } catch (std::exception& e) {
} results[index].error_msg =
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - m_start); utils::StringFormat("Build failed, error: %s\n", e.what());
results[index].elapsed_time += static_cast<double>(time_span.count()); }
}; auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - m_start);
results[index].elapsed_time += static_cast<double>(time_span.count());
};
// define how to run a candidate with the specified index // define how to run a candidate with the specified index
auto run_fn = [runner = runner_, &inputs, &build_results, &results](int index) { auto run_fn =
VLOG(6) << "Run candidate index: " << index; [runner = runner_, &inputs, &build_results, &results](int index) {
auto m_start = std::chrono::steady_clock::now(); VLOG(6) << "Run candidate index: " << index;
try { auto m_start = std::chrono::steady_clock::now();
// if error occurred in building, then skip running try {
if (results[index].error_msg.empty()) { // if error occurred in building, then skip running
results[index] = runner->Run(inputs[index], build_results[index]); if (results[index].error_msg.empty()) {
} results[index] = runner->Run(inputs[index], build_results[index]);
} catch (std::exception& e) { }
results[index].error_msg = utils::StringFormat("Run failed, error: %s\n", e.what()); } catch (std::exception& e) {
} results[index].error_msg =
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - m_start); utils::StringFormat("Run failed, error: %s\n", e.what());
results[index].elapsed_time += static_cast<double>(time_span.count()); }
}; auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - m_start);
results[index].elapsed_time += static_cast<double>(time_span.count());
};
// measure a candidate by calling build and run successively // measure a candidate by calling build and run successively
auto measure_fn = [&build_fn, &run_fn](int index) { auto measure_fn = [&build_fn, &run_fn](int index) {
build_fn(index); build_fn(index);
run_fn(index); run_fn(index);
}; };
// default num_threads_ is 1 and in that case it will perform all measurements sequentially inplace. // default num_threads_ is 1 and in that case it will perform all measurements
utils::parallel_run(measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_); // sequentially inplace.
utils::parallel_run(
measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_);
VLOG(4) << "Measure " << inputs.size() << " candidates"; VLOG(4) << "Measure " << inputs.size() << " candidates";
return results; return results;
......
...@@ -25,7 +25,9 @@ namespace auto_schedule { ...@@ -25,7 +25,9 @@ namespace auto_schedule {
// which are building the input schedules and running the generated codes. // which are building the input schedules and running the generated codes.
class ScheduleMeasurer { class ScheduleMeasurer {
public: public:
ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads = 1); ScheduleMeasurer(ScheduleBuilder* builder,
ScheduleRunner* runner,
int num_threads = 1);
// Measure a batch of inputs and return all results once. // Measure a batch of inputs and return all results once.
std::vector<MeasureResult> Measure(const std::vector<MeasureInput>& inputs); std::vector<MeasureResult> Measure(const std::vector<MeasureInput>& inputs);
......
...@@ -19,20 +19,24 @@ namespace auto_schedule { ...@@ -19,20 +19,24 @@ namespace auto_schedule {
using hlir::framework::GraphCompiler; using hlir::framework::GraphCompiler;
SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler) : graph_compiler_(graph_compiler) {} SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler)
: graph_compiler_(graph_compiler) {}
BuildResult SimpleBuilder::Build(const MeasureInput& input) { BuildResult SimpleBuilder::Build(const MeasureInput& input) {
CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr)) << "empty handle to GraphCompiler"; CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr))
<< "empty handle to GraphCompiler";
GraphCompiler::CompileOptions compile_options; GraphCompiler::CompileOptions compile_options;
compile_options.groups.emplace_back(input.task->subgraph); compile_options.groups.emplace_back(input.task->subgraph);
compile_options.lowered_funcs.emplace_back(input.lowered_funcs); compile_options.lowered_funcs.emplace_back(input.lowered_funcs);
compile_options.remove_unused_variables = false; compile_options.remove_unused_variables = false;
VLOG(5) << "call GraphCompiler to Build with Graph::Group size=" << compile_options.groups.size() VLOG(5) << "call GraphCompiler to Build with Graph::Group size="
<< ", lowered_funcs group size=" << compile_options.lowered_funcs.size(); << compile_options.groups.size() << ", lowered_funcs group size="
GraphCompiler::CompilationResult compiled_result = graph_compiler_->Build(compile_options); << compile_options.lowered_funcs.size();
GraphCompiler::CompilationResult compiled_result =
graph_compiler_->Build(compile_options);
BuildResult build_result; BuildResult build_result;
build_result.compiled_scope = graph_compiler_->GetScope().get(); build_result.compiled_scope = graph_compiler_->GetScope().get();
build_result.runtime_program = std::move(compiled_result.runtime_program); build_result.runtime_program = std::move(compiled_result.runtime_program);
return build_result; return build_result;
} }
......
...@@ -35,48 +35,64 @@ using hlir::framework::Tensor; ...@@ -35,48 +35,64 @@ using hlir::framework::Tensor;
// Parameters that needs to be initialized to 0. // Parameters that needs to be initialized to 0.
// Key is the Op name, and value is the index of the input parameter in the Op. // Key is the Op name, and value is the index of the input parameter in the Op.
static const std::unordered_map<std::string, std::vector<int>> kInitWithZeroParams = { static const std::unordered_map<std::string, std::vector<int>>
{"lookup_table", {1}}, kInitWithZeroParams = {
{"gather", {1}}, {"lookup_table", {1}},
{"gather_nd", {1}}, {"gather", {1}},
{"scatter_assign", {2}}, {"gather_nd", {1}},
{"scatter_add", {2}}, {"scatter_assign", {2}},
{"scatter_add", {2}},
}; };
// Generate random value and populate them to the output address of memory // Generate random value and populate them to the output address of memory
static void PopulateRandomValue(const common::Type& type, const int numel, void* raw_ptr) { static void PopulateRandomValue(const common::Type& type,
const int numel,
void* raw_ptr) {
std::random_device seed; std::random_device seed;
std::default_random_engine engine(seed()); std::default_random_engine engine(seed());
if (type == common::Bool()) { if (type == common::Bool()) {
auto* fmt_ptr = reinterpret_cast<bool*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<bool*>(raw_ptr);
std::bernoulli_distribution dist(0.5); std::bernoulli_distribution dist(0.5);
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::I32()) { } else if (type == common::I32()) {
auto* fmt_ptr = reinterpret_cast<int*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<int*>(raw_ptr);
std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(), std::numeric_limits<int>::max()); std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(),
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::numeric_limits<int>::max());
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::I64()) { } else if (type == common::I64()) {
auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr);
std::uniform_int_distribution<int64_t> dist(std::numeric_limits<int64_t>::min(), std::uniform_int_distribution<int64_t> dist(
std::numeric_limits<int64_t>::max()); std::numeric_limits<int64_t>::min(),
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::numeric_limits<int64_t>::max());
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::F32()) { } else if (type == common::F32()) {
auto* fmt_ptr = reinterpret_cast<float*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<float*>(raw_ptr);
std::uniform_real_distribution<float> dist(std::numeric_limits<float>::min(), std::numeric_limits<float>::max()); std::uniform_real_distribution<float> dist(
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else { } else {
CHECK_EQ(type.bytes(), 8) << "Unsupported type: " << type << ", type.bytes = " << type.bytes(); CHECK_EQ(type.bytes(), 8)
<< "Unsupported type: " << type << ", type.bytes = " << type.bytes();
auto* fmt_ptr = reinterpret_cast<uint8_t*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<uint8_t*>(raw_ptr);
std::uniform_int_distribution<uint8_t> dist(std::numeric_limits<uint8_t>::min(), std::uniform_int_distribution<uint8_t> dist(
std::numeric_limits<uint8_t>::max()); std::numeric_limits<uint8_t>::min(),
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::numeric_limits<uint8_t>::max());
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} }
} }
// Initialize a tensor with 0 if init_with_zero == true, otherwise initialize the tensor with random value. // Initialize a tensor with 0 if init_with_zero == true, otherwise initialize
static void InitTensorData(Tensor tensor, const common::Target& target, bool init_with_zero) { // the tensor with random value.
int mem_size = tensor->shape().numel() * tensor->type().bytes(); static void InitTensorData(Tensor tensor,
const common::Target& target,
bool init_with_zero) {
int mem_size = tensor->shape().numel() * tensor->type().bytes();
auto* tensor_data = tensor->mutable_data(target, tensor->type()); auto* tensor_data = tensor->mutable_data(target, tensor->type());
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
if (target == common::DefaultNVGPUTarget()) { if (target == common::DefaultNVGPUTarget()) {
...@@ -101,17 +117,20 @@ static void InitTensorData(Tensor tensor, const common::Target& target, bool ini ...@@ -101,17 +117,20 @@ static void InitTensorData(Tensor tensor, const common::Target& target, bool ini
// Find all parameter names in the task corresponding to the MeasureInput // Find all parameter names in the task corresponding to the MeasureInput
// that need to be initialized to 0 when measuring. // that need to be initialized to 0 when measuring.
static std::unordered_set<std::string> ParamsNeedInitWithZero(const MeasureInput& input) { static std::unordered_set<std::string> ParamsNeedInitWithZero(
const MeasureInput& input) {
std::unordered_set<std::string> res; std::unordered_set<std::string> res;
std::vector<hlir::framework::Node*> nodes = input.task->subgraph->CollectNodes(); std::vector<hlir::framework::Node*> nodes =
input.task->subgraph->CollectNodes();
for (auto* node : nodes) { for (auto* node : nodes) {
if (kInitWithZeroParams.count(node->op()->name) != 0) { if (kInitWithZeroParams.count(node->op()->name) != 0) {
std::vector<int> param_idxs = kInitWithZeroParams.at(node->op()->name); std::vector<int> param_idxs = kInitWithZeroParams.at(node->op()->name);
const auto& inlinks = node->inlinks_in_order(); const auto& inlinks = node->inlinks_in_order();
for (int param_idx : param_idxs) { for (int param_idx : param_idxs) {
CHECK_GT(inlinks.size(), param_idx); CHECK_GT(inlinks.size(), param_idx);
auto& edge = inlinks.at(param_idx); auto& edge = inlinks.at(param_idx);
std::string param_name = edge->source()->as<hlir::framework::NodeData>()->id(); std::string param_name =
edge->source()->as<hlir::framework::NodeData>()->id();
VLOG(6) << "param needs to be init with 0: " << param_name; VLOG(6) << "param needs to be init with 0: " << param_name;
res.insert(param_name); res.insert(param_name);
} }
...@@ -128,17 +147,19 @@ SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) { ...@@ -128,17 +147,19 @@ SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) {
// Prepare execution arguments of all instructions to run, a argument // Prepare execution arguments of all instructions to run, a argument
// may be obtained from the input of measurement or allocating new buffer // may be obtained from the input of measurement or allocating new buffer
// with random value. // with random value.
std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureInput& input, std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(
const BuildResult& build_result, const MeasureInput& input,
hlir::framework::Scope* temp_scope) { const BuildResult& build_result,
hlir::framework::Scope* temp_scope) {
std::map<std::string, cinn_pod_value_t> result; std::map<std::string, cinn_pod_value_t> result;
const auto& target = input.task->target; const auto& target = input.task->target;
const auto* input_args = input.execution_args; const auto* input_args = input.execution_args;
const auto* compiled_scope = build_result.compiled_scope; const auto* compiled_scope = build_result.compiled_scope;
const auto& instructions = build_result.runtime_program->GetRunInstructions(); const auto& instructions = build_result.runtime_program->GetRunInstructions();
std::unordered_set<std::string> params_need_init_with_zero = ParamsNeedInitWithZero(input); std::unordered_set<std::string> params_need_init_with_zero =
ParamsNeedInitWithZero(input);
auto fill_arg_fn = [&](const std::string& param) { auto fill_arg_fn = [&](const std::string& param) {
VLOG(6) << "Filling argument:" << param; VLOG(6) << "Filling argument:" << param;
...@@ -169,7 +190,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI ...@@ -169,7 +190,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI
temp_tensor->Resize(compiled_tensor->shape()); temp_tensor->Resize(compiled_tensor->shape());
temp_tensor->set_type(compiled_tensor->type()); temp_tensor->set_type(compiled_tensor->type());
temp_tensor->mutable_data(target, compiled_tensor->type()); temp_tensor->mutable_data(target, compiled_tensor->type());
InitTensorData(temp_tensor, target, params_need_init_with_zero.count(param) != 0); InitTensorData(
temp_tensor, target, params_need_init_with_zero.count(param) != 0);
result.emplace(param, temp_tensor->buffer()); result.emplace(param, temp_tensor->buffer());
}; };
...@@ -186,7 +208,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI ...@@ -186,7 +208,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI
return result; return result;
} }
MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& build_result) { MeasureResult SimpleRunner::Run(const MeasureInput& input,
const BuildResult& build_result) {
MeasureResult result; MeasureResult result;
auto t_start = std::chrono::steady_clock::now(); auto t_start = std::chrono::steady_clock::now();
// prepare execution arguments // prepare execution arguments
...@@ -195,7 +218,7 @@ MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& bu ...@@ -195,7 +218,7 @@ MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& bu
auto execution_args = PrepareArgs(input, build_result, &temp_scope); auto execution_args = PrepareArgs(input, build_result, &temp_scope);
// Execute each instruction repeatedly and take the average as cost. // Execute each instruction repeatedly and take the average as cost.
result.execution_cost = 0; result.execution_cost = 0;
const auto& instructions = build_result.runtime_program->GetRunInstructions(); const auto& instructions = build_result.runtime_program->GetRunInstructions();
for (auto ct = 0; ct < instructions.size(); ++ct) { for (auto ct = 0; ct < instructions.size(); ++ct) {
auto&& instr = instructions.at(ct); auto&& instr = instructions.at(ct);
...@@ -209,16 +232,18 @@ MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& bu ...@@ -209,16 +232,18 @@ MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& bu
CUDA_CALL(cudaDeviceSynchronize()); CUDA_CALL(cudaDeviceSynchronize());
} }
#endif #endif
auto time_span = auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - run_start); std::chrono::steady_clock::now() - run_start);
auto cost_avg = static_cast<double>(time_span.count()) / repeat_times_; auto cost_avg = static_cast<double>(time_span.count()) / repeat_times_;
result.execution_cost += cost_avg; result.execution_cost += cost_avg;
} }
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - t_start); auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - t_start);
result.elapsed_time = static_cast<double>(time_span.count()); result.elapsed_time = static_cast<double>(time_span.count());
VLOG(4) << "A measurement done:repeat_times[" << repeat_times_ << "]total_elapsed_time[" << result.elapsed_time VLOG(4) << "A measurement done:repeat_times[" << repeat_times_
<< "]total_elapsed_time[" << result.elapsed_time
<< "]us,execution_cost[" << result.execution_cost << "]us"; << "]us,execution_cost[" << result.execution_cost << "]us";
return result; return result;
} }
......
...@@ -26,12 +26,14 @@ class SimpleRunner : public ScheduleRunner { ...@@ -26,12 +26,14 @@ class SimpleRunner : public ScheduleRunner {
public: public:
SimpleRunner(int repeat_times); SimpleRunner(int repeat_times);
MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override; MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) override;
private: private:
std::map<std::string, cinn_pod_value_t> PrepareArgs(const MeasureInput& input, std::map<std::string, cinn_pod_value_t> PrepareArgs(
const BuildResult& build_result, const MeasureInput& input,
hlir::framework::Scope* temp_scope); const BuildResult& build_result,
hlir::framework::Scope* temp_scope);
private: private:
// The repeat times of running instructions, // The repeat times of running instructions,
......
...@@ -53,15 +53,16 @@ class TestSimpleRunner : public ::testing::Test { ...@@ -53,15 +53,16 @@ class TestSimpleRunner : public ::testing::Test {
static frontend::Program CreateAddReluProgram(); static frontend::Program CreateAddReluProgram();
void SetUp() override { void SetUp() override {
std::unordered_set<std::string> fetch_ids; std::unordered_set<std::string> fetch_ids;
auto program = CreateAddReluProgram(); auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph); compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph); graph_compiler =
auto runtime_program = graph_compiler->Build(); std::make_unique<GraphCompiler>(target, compiled_scope, graph);
auto runtime_program = graph_compiler->Build();
const auto& instructions = runtime_program->GetRunInstructions(); const auto& instructions = runtime_program->GetRunInstructions();
ASSERT_EQ(1, instructions.size()); ASSERT_EQ(1, instructions.size());
build_result.compiled_scope = compiled_scope.get(); build_result.compiled_scope = compiled_scope.get();
build_result.runtime_program = std::move(runtime_program); build_result.runtime_program = std::move(runtime_program);
task = std::make_unique<TuneTask>(); task = std::make_unique<TuneTask>();
...@@ -71,7 +72,7 @@ class TestSimpleRunner : public ::testing::Test { ...@@ -71,7 +72,7 @@ class TestSimpleRunner : public ::testing::Test {
task->target = common::DefaultHostTarget(); task->target = common::DefaultHostTarget();
#endif #endif
task->subgraph = graph->fusion_groups.front(); task->subgraph = graph->fusion_groups.front();
input.task = task.get(); input.task = task.get();
} }
}; };
...@@ -115,18 +116,22 @@ TEST_F(TestSimpleRunner, TimeMeasured) { ...@@ -115,18 +116,22 @@ TEST_F(TestSimpleRunner, TimeMeasured) {
BuildResult build_result; BuildResult build_result;
build_result.compiled_scope = nullptr; build_result.compiled_scope = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions; std::vector<std::unique_ptr<Instruction>> instructions;
instructions.emplace_back( instructions.emplace_back(new Instruction(common::DefaultHostTarget(),
new Instruction(common::DefaultHostTarget(), nullptr, {}, {"empty_placeholder"}, "sleep_fn")); nullptr,
{},
{"empty_placeholder"},
"sleep_fn"));
instructions.back()->SetLoweredFunc(reinterpret_cast<void*>(sleep_fn)); instructions.back()->SetLoweredFunc(reinterpret_cast<void*>(sleep_fn));
instructions.back()->Finalize(); instructions.back()->Finalize();
build_result.runtime_program.reset(new hlir::framework::Program(nullptr, std::move(instructions))); build_result.runtime_program.reset(
new hlir::framework::Program(nullptr, std::move(instructions)));
// to skip the condition check of params in Instruction::PreparePodArgs // to skip the condition check of params in Instruction::PreparePodArgs
std::map<std::string, cinn_pod_value_t> preset_args; std::map<std::string, cinn_pod_value_t> preset_args;
preset_args.emplace("empty_placeholder", cinn_pod_value_t()); preset_args.emplace("empty_placeholder", cinn_pod_value_t());
input.execution_args = &preset_args; input.execution_args = &preset_args;
auto runner = std::make_unique<SimpleRunner>(2); auto runner = std::make_unique<SimpleRunner>(2);
MeasureResult measure_result = runner->Run(input, build_result); MeasureResult measure_result = runner->Run(input, build_result);
// because the kernel function will sleep 100 us, // because the kernel function will sleep 100 us,
// the cost time of execution and span in total must // the cost time of execution and span in total must
......
...@@ -22,10 +22,12 @@ ...@@ -22,10 +22,12 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
int ExtractNumThreads(const ir::IRSchedule& ir_schedule, const std::string& bind_axis) { int ExtractNumThreads(const ir::IRSchedule& ir_schedule,
const std::string& bind_axis) {
const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc(); const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc();
for (auto&& step : trace.Steps()) { for (auto&& step : trace.Steps()) {
if (step.type == "Bind" && step.attrs.find("thread_axis") != step.attrs.end() && if (step.type == "Bind" &&
step.attrs.find("thread_axis") != step.attrs.end() &&
absl::get<std::string>(step.attrs.at("thread_axis")) == bind_axis) { absl::get<std::string>(step.attrs.at("thread_axis")) == bind_axis) {
CHECK_EQ(step.inputs.at("loop").size(), 1); CHECK_EQ(step.inputs.at("loop").size(), 1);
return step.inputs.at("loop")[0].As<ir::For>()->extent.as_int32(); return step.inputs.at("loop")[0].As<ir::For>()->extent.as_int32();
...@@ -38,17 +40,21 @@ std::vector<std::string> FindCandidates(const ir::ScheduleDesc& trace) { ...@@ -38,17 +40,21 @@ std::vector<std::string> FindCandidates(const ir::ScheduleDesc& trace) {
std::vector<std::string> candidate_block_names; std::vector<std::string> candidate_block_names;
for (auto&& step : trace.Steps()) { for (auto&& step : trace.Steps()) {
if (step.type == "AnnotateIntAttr" && if (step.type == "AnnotateIntAttr" &&
absl::get<std::string>(step.attrs.at("key")) == ir::attr::cooperative_process) { absl::get<std::string>(step.attrs.at("key")) ==
ir::attr::cooperative_process) {
candidate_block_names.push_back( candidate_block_names.push_back(
step.inputs.at("block")[0].As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name); step.inputs.at("block")[0]
.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name);
} }
} }
return candidate_block_names; return candidate_block_names;
} }
bool CooperativeProcess::Apply(ir::IRSchedule* schedule) { bool CooperativeProcess::Apply(ir::IRSchedule* schedule) {
int num_threads = ExtractNumThreads(*schedule, "threadIdx.x"); int num_threads = ExtractNumThreads(*schedule, "threadIdx.x");
const ir::ScheduleDesc& trace = schedule->GetTraceDesc(); const ir::ScheduleDesc& trace = schedule->GetTraceDesc();
std::vector<std::string> candidate_block_names = FindCandidates(trace); std::vector<std::string> candidate_block_names = FindCandidates(trace);
for (auto&& candidate : candidate_block_names) { for (auto&& candidate : candidate_block_names) {
auto loop = schedule->GetLoops(candidate).back(); auto loop = schedule->GetLoops(candidate).back();
......
...@@ -20,8 +20,9 @@ namespace cinn { ...@@ -20,8 +20,9 @@ namespace cinn {
namespace auto_schedule { namespace auto_schedule {
/* /*
* @brief Rewrite the cooperative_process annotation to actually bind the loop on threadIdx. * @brief Rewrite the cooperative_process annotation to actually bind the loop
* This rule is used for collaborative data handling of multiple threads within the same block. * on threadIdx. This rule is used for collaborative data handling of multiple
* threads within the same block.
*/ */
class CooperativeProcess : public PostScheduleRule { class CooperativeProcess : public PostScheduleRule {
public: public:
......
...@@ -31,57 +31,75 @@ class TestCooperativeProcess : public TestAutoGenRuleBase { ...@@ -31,57 +31,75 @@ class TestCooperativeProcess : public TestAutoGenRuleBase {
}; };
TEST_F(TestCooperativeProcess, Matmul) { TEST_F(TestCooperativeProcess, Matmul) {
default_input_names = {"X", "Y"}; default_input_names = {"X", "Y"};
default_output_names = {"temp_matmul_out"}; default_output_names = {"temp_matmul_out"};
std::vector<int32_t> X_shape = {32, 32}; std::vector<int32_t> X_shape = {32, 32};
std::vector<int32_t> Y_shape = {32, 32}; std::vector<int32_t> Y_shape = {32, 32};
std::vector<int32_t> out_shape = {32, 32}; std::vector<int32_t> out_shape = {32, 32};
int num_blocks_y = 2; int num_blocks_y = 2;
int num_blocks_x = 2; int num_blocks_x = 2;
int num_threads_y = 8; int num_threads_y = 8;
int num_threads_x = 2; int num_threads_x = 2;
int steps_k = 8; int steps_k = 8;
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); frontend::Program matmul_op =
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);
// split loops // split loops
std::vector<ir::Expr> loops = ir_schedule.GetLoops("temp_matmul_out"); std::vector<ir::Expr> loops = ir_schedule.GetLoops("temp_matmul_out");
std::vector<ir::Expr> k_loops = ir_schedule.Split(loops[2], {steps_k, -1}); std::vector<ir::Expr> k_loops = ir_schedule.Split(loops[2], {steps_k, -1});
std::vector<ir::Expr> j_loops = ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1}); std::vector<ir::Expr> j_loops =
std::vector<ir::Expr> i_loops = ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1}); ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1});
std::vector<ir::Expr> i_loops =
ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1});
// reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2 // reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.Reorder({loops[0], loops[3], loops[1], loops[4], loops[6], loops[7], loops[2], loops[5]}); ir_schedule.Reorder({loops[0],
loops[3],
loops[1],
loops[4],
loops[6],
loops[7],
loops[2],
loops[5]});
// fuse and bind // fuse and bind
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir::Expr i1_j1_fused = ir_schedule.Fuse({loops[2], loops[3]}); ir::Expr i1_j1_fused = ir_schedule.Fuse({loops[2], loops[3]});
ir::Expr i0_j0_fused = ir_schedule.Fuse({loops[0], loops[1]}); ir::Expr i0_j0_fused = ir_schedule.Fuse({loops[0], loops[1]});
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.Bind(loops[1], "threadIdx.x"); ir_schedule.Bind(loops[1], "threadIdx.x");
ir_schedule.Bind(loops[0], "blockIdx.x"); ir_schedule.Bind(loops[0], "blockIdx.x");
// cache read // cache read
ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out"); ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out");
ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared"); ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared");
std::string X_cache_block_name = std::string X_cache_block_name = X_cache_block.As<ir::ScheduleBlockRealize>()
X_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name; ->schedule_block.As<ir::ScheduleBlock>()
->name;
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.ComputeAt(X_cache_block, loops[2]); ir_schedule.ComputeAt(X_cache_block, loops[2]);
std::vector<ir::Expr> X_cache_loops = ir_schedule.GetLoops(X_cache_block_name); std::vector<ir::Expr> X_cache_loops =
ir_schedule.GetLoops(X_cache_block_name);
ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]}); ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]});
ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name), ir::attr::cooperative_process, 0); ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name),
ir::attr::cooperative_process,
0);
out_block = ir_schedule.GetBlock("temp_matmul_out"); out_block = ir_schedule.GetBlock("temp_matmul_out");
ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared"); ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared");
std::string Y_cache_block_name = std::string Y_cache_block_name = Y_cache_block.As<ir::ScheduleBlockRealize>()
Y_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name; ->schedule_block.As<ir::ScheduleBlock>()
->name;
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.ComputeAt(Y_cache_block, loops[2]); ir_schedule.ComputeAt(Y_cache_block, loops[2]);
std::vector<ir::Expr> Y_cache_loops = ir_schedule.GetLoops(Y_cache_block_name); std::vector<ir::Expr> Y_cache_loops =
ir_schedule.GetLoops(Y_cache_block_name);
ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]}); ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]});
ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name), ir::attr::cooperative_process, 0); ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name),
ir::attr::cooperative_process,
0);
// apply CooperativeProcess // apply CooperativeProcess
CooperativeProcess cooperative_process; CooperativeProcess cooperative_process;
...@@ -180,14 +198,15 @@ TEST_F(TestCooperativeProcess, Matmul) { ...@@ -180,14 +198,15 @@ TEST_F(TestCooperativeProcess, Matmul) {
ASSERT_EQ(ir, expected_ir); ASSERT_EQ(ir, expected_ir);
// build ir::Module and debug source code // build ir::Module and debug source code
auto ir_module = BuildIRModule(ir_schedule); auto ir_module = BuildIRModule(ir_schedule);
auto source_code = GenSourceCode(ir_module); auto source_code = GenSourceCode(ir_module);
VLOG(6) << "scheduled source code:\n" << source_code; VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision // execute and check precision
CheckResult( CheckResult(
GenExecutableKernel(ir_module), GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), GenExecutableKernel(BuildIRModule(MakeIRSchedule(
matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names, default_input_names,
default_output_names, default_output_names,
{X_shape, Y_shape}, {X_shape, Y_shape},
......
...@@ -29,37 +29,45 @@ static constexpr uint32_t kMaxBlocks = 256; ...@@ -29,37 +29,45 @@ static constexpr uint32_t kMaxBlocks = 256;
bool IsSpatialLoop(const ir::For* for_node) { bool IsSpatialLoop(const ir::For* for_node) {
if (for_node->for_type() != ir::ForType::Serial) return false; if (for_node->for_type() != ir::ForType::Serial) return false;
const auto& loop_var = for_node->loop_var; const auto& loop_var = for_node->loop_var;
// collect cases where the loop_var used in one of reduce axis in underneath ScheduleBlock // collect cases where the loop_var used in one of reduce axis in underneath
auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(for_node->body, [&loop_var](const Expr* x) { // ScheduleBlock
const auto* block_realize = x->As<ir::ScheduleBlockRealize>(); auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(
if (!block_realize) return false; for_node->body, [&loop_var](const Expr* x) {
const auto* block_realize = x->As<ir::ScheduleBlockRealize>();
const auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>(); if (!block_realize) return false;
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
CHECK_EQ(block_realize->iter_values.size(), schedule_block->iter_vars.size()); const auto* schedule_block =
for (int i = 0; i < block_realize->iter_values.size(); ++i) { block_realize->schedule_block.As<ir::ScheduleBlock>();
const ir::Var& iter_var = schedule_block->iter_vars[i]; CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
const ir::Expr& binding = block_realize->iter_values[i]; CHECK_EQ(block_realize->iter_values.size(),
if (iter_var->is_reduce_axis || iter_var->name.substr(0, 6) == "reduce") { schedule_block->iter_vars.size());
auto used_exprs = ir::CollectIRNodesWithoutTensor(binding, [&loop_var](const Expr* x) { for (int i = 0; i < block_realize->iter_values.size(); ++i) {
const ir::_Var_* var = x->As<ir::_Var_>(); const ir::Var& iter_var = schedule_block->iter_vars[i];
if (var && (x->same_as(loop_var) || var->name == loop_var->name)) { const ir::Expr& binding = block_realize->iter_values[i];
return true; if (iter_var->is_reduce_axis ||
iter_var->name.substr(0, 6) == "reduce") {
auto used_exprs = ir::CollectIRNodesWithoutTensor(
binding, [&loop_var](const Expr* x) {
const ir::_Var_* var = x->As<ir::_Var_>();
if (var &&
(x->same_as(loop_var) || var->name == loop_var->name)) {
return true;
}
return false;
});
if (!used_exprs.empty()) return true;
} }
return false; }
});
if (!used_exprs.empty()) return true;
}
}
return false; return false;
}); });
if (!used_for_reduce_axis.empty()) return false; if (!used_for_reduce_axis.empty()) return false;
return true; return true;
} }
// count the number of loops that can be binded from the input for_node to bottom // count the number of loops that can be binded from the input for_node to
// bottom
int CountLoopCanBinded(const ir::For* for_node) { int CountLoopCanBinded(const ir::For* for_node) {
int cnt = 0; int cnt = 0;
while (for_node) { while (for_node) {
...@@ -68,9 +76,11 @@ int CountLoopCanBinded(const ir::For* for_node) { ...@@ -68,9 +76,11 @@ int CountLoopCanBinded(const ir::For* for_node) {
cnt += 1; cnt += 1;
CHECK(for_node->body.defined() && for_node->body.As<ir::Block>()) << "Body is not defined"; CHECK(for_node->body.defined() && for_node->body.As<ir::Block>())
<< "Body is not defined";
const ir::Block* body = for_node->body.As<ir::Block>(); const ir::Block* body = for_node->body.As<ir::Block>();
// terminate when body of this loop has more than one statement or the body is not a ir::For node // terminate when body of this loop has more than one statement or the body
// is not a ir::For node
for_node = body->stmts.size() == 1 ? body->stmts[0].As<ir::For>() : nullptr; for_node = body->stmts.size() == 1 ? body->stmts[0].As<ir::For>() : nullptr;
} }
return cnt; return cnt;
...@@ -82,14 +92,18 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, ...@@ -82,14 +92,18 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule,
int max_blocks, int max_blocks,
int max_threads_per_block) { int max_threads_per_block) {
auto all_loops = ir_schedule->GetLoops(block_name); auto all_loops = ir_schedule->GetLoops(block_name);
CHECK_LE(num_loops_to_bind, all_loops.size()) << "The number of loops to be bind is greater than size of all_loops"; CHECK_LE(num_loops_to_bind, all_loops.size())
// check whether it is the case that threadIdx has been binded but blockIdx not, << "The number of loops to be bind is greater than size of all_loops";
// the threadIdx can only be binded in the first loop after num_loops_to_bind loops // check whether it is the case that threadIdx has been binded but blockIdx
// because we has excluded other cases in CountLoopCanBinded // not, the threadIdx can only be binded in the first loop after
// num_loops_to_bind loops because we has excluded other cases in
// CountLoopCanBinded
bool gpu_thread_has_binded = bool gpu_thread_has_binded =
num_loops_to_bind < all_loops.size() && all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded(); num_loops_to_bind < all_loops.size() &&
Expr fused_loop = ir_schedule->Fuse({all_loops.begin(), all_loops.begin() + num_loops_to_bind}); all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded();
int32_t extent = fused_loop.As<ir::For>()->extent.as_int32(); Expr fused_loop = ir_schedule->Fuse(
{all_loops.begin(), all_loops.begin() + num_loops_to_bind});
int32_t extent = fused_loop.As<ir::For>()->extent.as_int32();
if (gpu_thread_has_binded) { if (gpu_thread_has_binded) {
ir_schedule->Bind(fused_loop, "blockIdx.x"); ir_schedule->Bind(fused_loop, "blockIdx.x");
return; return;
...@@ -106,7 +120,8 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, ...@@ -106,7 +120,8 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule,
ir_schedule->Bind(splits[0], "blockIdx.x"); ir_schedule->Bind(splits[0], "blockIdx.x");
ir_schedule->Bind(splits[1], "threadIdx.x"); ir_schedule->Bind(splits[1], "threadIdx.x");
} else { } else {
auto splits = ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block}); auto splits =
ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block});
CHECK_EQ(splits.size(), 3); CHECK_EQ(splits.size(), 3);
ir_schedule->Reorder({splits[1], splits[2], splits[0]}); ir_schedule->Reorder({splits[1], splits[2], splits[0]});
all_loops = ir_schedule->GetLoops(block_name); all_loops = ir_schedule->GetLoops(block_name);
...@@ -126,31 +141,38 @@ RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) { ...@@ -126,31 +141,38 @@ RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) {
} }
num_applicable_ = applicable_schedule_blocks_.size(); num_applicable_ = applicable_schedule_blocks_.size();
VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_; VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
void AutoBind::Apply(int index) { void AutoBind::Apply(int index) {
CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index; CHECK_LT(index, applicable_schedule_blocks_.size())
<< "invalid apply index:" << index;
auto applied_block = applicable_schedule_blocks_.at(index); auto applied_block = applicable_schedule_blocks_.at(index);
auto all_loops = ir_schedule_->GetLoops(applied_block); auto all_loops = ir_schedule_->GetLoops(applied_block);
BindGPUIndex(ir_schedule_, BindGPUIndex(ir_schedule_,
applied_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name, applied_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name,
CountLoopCanBinded(all_loops[0].As<ir::For>()), CountLoopCanBinded(all_loops[0].As<ir::For>()),
kMaxBlocks, kMaxBlocks,
target_->max_num_threads()); target_->max_num_threads());
return; return;
} }
RuleApplyType AutoBind::AnalyseApplyType(SearchState state, const std::string& block_name) const { RuleApplyType AutoBind::AnalyseApplyType(SearchState state,
const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name); Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto all_loops = state->ir_schedule.GetLoops(block_expr); auto all_loops = state->ir_schedule.GetLoops(block_expr);
return CountLoopCanBinded(all_loops[0].As<ir::For>()) > 0 ? RuleApplyType::kApplyAndPruneOtherRules return CountLoopCanBinded(all_loops[0].As<ir::For>()) > 0
: RuleApplyType::kCannotApply; ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state, const std::string& block_name) { std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state,
const std::string& block_name) {
SearchState new_state = state.Copy(); SearchState new_state = state.Copy();
auto all_loops = state->ir_schedule.GetLoops(block_name); auto all_loops = state->ir_schedule.GetLoops(block_name);
BindGPUIndex(&new_state->ir_schedule, BindGPUIndex(&new_state->ir_schedule,
block_name, block_name,
CountLoopCanBinded(all_loops[0].As<ir::For>()), CountLoopCanBinded(all_loops[0].As<ir::For>()),
......
...@@ -36,9 +36,11 @@ class AutoBind : public AutoGenRule { ...@@ -36,9 +36,11 @@ class AutoBind : public AutoGenRule {
std::string GetRuleName() const override { return "AutoBind"; } std::string GetRuleName() const override { return "AutoBind"; }
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override; std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private: private:
std::vector<Expr> applicable_schedule_blocks_; std::vector<Expr> applicable_schedule_blocks_;
......
...@@ -28,17 +28,19 @@ ...@@ -28,17 +28,19 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
static constexpr uint32_t kMaxBlocks = 256; static constexpr uint32_t kMaxBlocks = 256;
static constexpr uint32_t kMaxThreadsPerBlock = 1024; static constexpr uint32_t kMaxThreadsPerBlock = 1024;
class TestAutoBind : public TestAutoGenRuleBase { class TestAutoBind : public TestAutoGenRuleBase {
public: public:
std::vector<std::string> default_input_names = {"X", "Y"}; std::vector<std::string> default_input_names = {"X", "Y"};
std::vector<std::string> default_output_names = {"temp_matmul_out"}; std::vector<std::string> default_output_names = {"temp_matmul_out"};
void TestApplyOnElementWiseAdd(const std::vector<int>& shape, const std::string& block_name) { void TestApplyOnElementWiseAdd(const std::vector<int>& shape,
const std::string& block_name) {
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
auto test_program = tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}}); auto test_program =
tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}});
// construct input parameter // construct input parameter
ir::IRSchedule ir_schedule = MakeIRSchedule(test_program); ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
...@@ -48,15 +50,17 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -48,15 +50,17 @@ class TestAutoBind : public TestAutoGenRuleBase {
// apply // apply
AutoBind auto_bind(target_); AutoBind auto_bind(target_);
ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name), RuleApplyType::kApplyAndPruneOtherRules); ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name),
auto result = auto_bind.ApplyOnBlock(state, block_name)[0]; RuleApplyType::kApplyAndPruneOtherRules);
auto result = auto_bind.ApplyOnBlock(state, block_name)[0];
std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "AutoBind applied Expr: " << exprs[0]; VLOG(6) << "AutoBind applied Expr: " << exprs[0];
// check bind result // check bind result
auto all_loops = result->ir_schedule.GetLoops(block_name); auto all_loops = result->ir_schedule.GetLoops(block_name);
int total_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); int total_num =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
if (total_num <= kMaxThreadsPerBlock) { if (total_num <= kMaxThreadsPerBlock) {
ASSERT_EQ(all_loops.size(), 1); ASSERT_EQ(all_loops.size(), 1);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), total_num); EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), total_num);
...@@ -64,27 +68,33 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -64,27 +68,33 @@ class TestAutoBind : public TestAutoGenRuleBase {
} else if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) { } else if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) {
ASSERT_EQ(all_loops.size(), 2); ASSERT_EQ(all_loops.size(), 2);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(),
static_cast<int32_t>(std::ceil(double(total_num) / kMaxThreadsPerBlock))); static_cast<int32_t>(
std::ceil(double(total_num) / kMaxThreadsPerBlock)));
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded()); EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded());
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(), kMaxThreadsPerBlock); EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(),
kMaxThreadsPerBlock);
EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded()); EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded());
} else { } else {
ASSERT_EQ(all_loops.size(), 3); ASSERT_EQ(all_loops.size(), 3);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), kMaxBlocks); EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), kMaxBlocks);
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded()); EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded());
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(), kMaxThreadsPerBlock); EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(),
kMaxThreadsPerBlock);
EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded()); EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded());
EXPECT_EQ(all_loops[2].As<ir::For>()->extent.as_int32(), EXPECT_EQ(all_loops[2].As<ir::For>()->extent.as_int32(),
static_cast<int32_t>(std::ceil(double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock)))); static_cast<int32_t>(std::ceil(
double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock))));
EXPECT_FALSE(all_loops[2].As<ir::For>()->is_binded()); EXPECT_FALSE(all_loops[2].As<ir::For>()->is_binded());
} }
// build and run // build and run
auto ir_module = BuildIRModule(result->ir_schedule); auto ir_module = BuildIRModule(result->ir_schedule);
auto source_code = GenSourceCode(ir_module); auto source_code = GenSourceCode(ir_module);
VLOG(6) << "Optimized source code:\n" << source_code; VLOG(6) << "Optimized source code:\n" << source_code;
auto manual_ir_module = BuildIRModule(MakeIRSchedule(test_program, /* apply_manual_schedule*/ true)); auto manual_ir_module = BuildIRModule(
VLOG(6) << "Manual-schedule compiled source code:\n" << GenSourceCode(manual_ir_module); MakeIRSchedule(test_program, /* apply_manual_schedule*/ true));
VLOG(6) << "Manual-schedule compiled source code:\n"
<< GenSourceCode(manual_ir_module);
CheckResult(GenExecutableKernel(ir_module), CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(manual_ir_module), GenExecutableKernel(manual_ir_module),
default_input_names, default_input_names,
...@@ -97,16 +107,20 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -97,16 +107,20 @@ class TestAutoBind : public TestAutoGenRuleBase {
TEST_F(TestAutoBind, AnalyseApplyType) { TEST_F(TestAutoBind, AnalyseApplyType) {
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}})); ir::IRSchedule ir_schedule = MakeIRSchedule(
tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}}));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
AutoBind auto_bind(target_); AutoBind auto_bind(target_);
const std::string& applied_block_name = default_output_names.back(); const std::string& applied_block_name = default_output_names.back();
// outer two loops of initial Expr are spatial loops, so it can be applied // outer two loops of initial Expr are spatial loops, so it can be applied
EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name),
RuleApplyType::kApplyAndPruneOtherRules);
state->ir_schedule.Fuse(applied_block_name, {0, 1}); state->ir_schedule.Fuse(applied_block_name, {0, 1});
state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0], "threadIdx.x"); state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0],
"threadIdx.x");
// after fuse and bind, there is no loops to be binded. // after fuse and bind, there is no loops to be binded.
EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kCannotApply); EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name),
RuleApplyType::kCannotApply);
} }
TEST_F(TestAutoBind, ApplyOnBlock) { TEST_F(TestAutoBind, ApplyOnBlock) {
......
...@@ -27,12 +27,16 @@ namespace auto_schedule { ...@@ -27,12 +27,16 @@ namespace auto_schedule {
AutoGenRule::AutoGenRule(const common::Target& target) : target_(&target) {} AutoGenRule::AutoGenRule(const common::Target& target) : target_(&target) {}
int AutoGenRule::NumberApplicable() const { int AutoGenRule::NumberApplicable() const {
CHECK_GE(num_applicable_, 0) << "Call " << GetRuleName() << "::NumberApplicable() without initialization."; CHECK_GE(num_applicable_, 0)
<< "Call " << GetRuleName()
<< "::NumberApplicable() without initialization.";
return num_applicable_; return num_applicable_;
} }
void AutoGenRule::ApplyRandomly() { void AutoGenRule::ApplyRandomly() {
CHECK_GT(num_applicable_, 0) << "Call " << GetRuleName() << "::ApplyRandomly() with NumberApplicable() == 0"; CHECK_GT(num_applicable_, 0)
<< "Call " << GetRuleName()
<< "::ApplyRandomly() with NumberApplicable() == 0";
int index = rand() % num_applicable_; int index = rand() % num_applicable_;
return Apply(index); return Apply(index);
} }
......
...@@ -29,15 +29,18 @@ enum class RuleApplyType : int { ...@@ -29,15 +29,18 @@ enum class RuleApplyType : int {
// This rule cannot be applied to ModuleExpr. // This rule cannot be applied to ModuleExpr.
kCannotApply = 0, kCannotApply = 0,
// This rule can be applied to ModuleExpr, // This rule can be applied to ModuleExpr,
// and the original ModuleExpr will be retained for branching with other rules. // and the original ModuleExpr will be retained for branching with other
// rules.
kApply = 1, kApply = 1,
// This rule can be applied, but the original ModuleExpr will be deleted, // This rule can be applied, but the original ModuleExpr will be deleted,
// so the branches with other rules applied on the original ModuleExpr will be pruned. // so the branches with other rules applied on the original ModuleExpr will be
// pruned.
kApplyAndPruneOtherRules = 2, kApplyAndPruneOtherRules = 2,
}; };
/** /**
* Base class for rules of auto-generating schedule (like Ansor's sketch generation) * Base class for rules of auto-generating schedule (like Ansor's sketch
* generation)
* *
*/ */
class AutoGenRule { class AutoGenRule {
...@@ -46,7 +49,8 @@ class AutoGenRule { ...@@ -46,7 +49,8 @@ class AutoGenRule {
~AutoGenRule() = default; ~AutoGenRule() = default;
// Initialize the AutoGenRule, it must be called before further actions. // Initialize the AutoGenRule, it must be called before further actions.
// Returns false if the rule cannot be applied on the mod_expr, true otherwise. // Returns false if the rule cannot be applied on the mod_expr, true
// otherwise.
virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0; virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0;
// CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so // CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so
...@@ -65,11 +69,15 @@ class AutoGenRule { ...@@ -65,11 +69,15 @@ class AutoGenRule {
// Returns the name of the rule, used for debug. // Returns the name of the rule, used for debug.
virtual std::string GetRuleName() const = 0; virtual std::string GetRuleName() const = 0;
// Analyze the ApplyType of the rule used for a block determined by a specific SearchState and block name // Analyze the ApplyType of the rule used for a block determined by a specific
virtual RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const = 0; // SearchState and block name
virtual RuleApplyType AnalyseApplyType(
SearchState state, const std::string& block_name) const = 0;
// Apply the rule to a block determined by a specific SearchState and block name // Apply the rule to a block determined by a specific SearchState and block
virtual std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) = 0; // name
virtual std::vector<SearchState> ApplyOnBlock(
SearchState state, const std::string& block_name) = 0;
protected: protected:
// number of ScheduleBlock that can apply this auto gen rule // number of ScheduleBlock that can apply this auto gen rule
......
...@@ -34,31 +34,38 @@ ...@@ -34,31 +34,38 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
AutoInline::AutoInline(const common::Target& target, const std::unordered_set<std::string>& no_inline_output_names) AutoInline::AutoInline(
const common::Target& target,
const std::unordered_set<std::string>& no_inline_output_names)
: AutoGenRule(target), no_inline_output_names_(no_inline_output_names) {} : AutoGenRule(target), no_inline_output_names_(no_inline_output_names) {}
bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const { bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As<ir::ScheduleBlockRealize>(); ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); const ir::ScheduleBlockRealize* sche_block_realize =
ir::Expr compute_body = sche_block->body; sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr); const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
ir::Expr compute_body = sche_block->body;
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);
// Check the schedule block to be inlined is not a reduce tensor. // Check the schedule block to be inlined is not a reduce tensor.
std::set<ir::Expr> find_store = std::set<ir::Expr> find_store = ir::CollectIRNodesWithoutTensor(
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { return x->As<ir::Store>(); }); compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) { if (find_store.size() != 1UL) {
return false; return false;
} }
ir::Expr tensor_expr = (*find_store.begin()).As<ir::Store>()->tensor; ir::Expr tensor_expr = (*find_store.begin()).As<ir::Store>()->tensor;
ir::Tensor tensor = tensor_expr.as_tensor_ref(); ir::Tensor tensor = tensor_expr.as_tensor_ref();
if (tensor->is_reduce_tensor()) { if (tensor->is_reduce_tensor()) {
return false; return false;
} }
// LoweredFunc output can be tensor name or tensor buffer name // LoweredFunc output can be tensor name or tensor buffer name
if (no_inline_output_names_.find(tensor->name) != no_inline_output_names_.end() || if (no_inline_output_names_.find(tensor->name) !=
no_inline_output_names_.find(tensor->buffer->name) != no_inline_output_names_.end()) { no_inline_output_names_.end() ||
no_inline_output_names_.find(tensor->buffer->name) !=
no_inline_output_names_.end()) {
return false; return false;
} }
...@@ -70,26 +77,32 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir:: ...@@ -70,26 +77,32 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::
// Check this schedule block is the only writer of the tensor. // Check this schedule block is the only writer of the tensor.
find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() && (x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name; return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name;
}); });
if (find_store.size() != 1UL) { if (find_store.size() != 1UL) {
return false; return false;
} }
// Check there is no overlap between the buffers the schedule block reads and writes. // Check there is no overlap between the buffers the schedule block reads and
std::set<ir::Expr> find_load = ir::CollectIRNodesWithoutTensor( // writes.
compute_body, [&](const Expr* x) { return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr; }); std::set<ir::Expr> find_load =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr;
});
if (!find_load.empty()) { if (!find_load.empty()) {
return false; return false;
} }
ir::Expr store = *(find_store.begin()); ir::Expr store = *(find_store.begin());
ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(), store); ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(),
store);
if (!inliner.BodyPatternAllowInline()) { if (!inliner.BodyPatternAllowInline()) {
return false; return false;
} }
ir::LeafBlockRemovalPlan remove_plan(sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt); ir::LeafBlockRemovalPlan remove_plan(
sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt);
remove_plan(&root); remove_plan(&root);
if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) { if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) {
return false; return false;
...@@ -99,16 +112,20 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir:: ...@@ -99,16 +112,20 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::
return true; return true;
} }
AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const { AutoInlineType AutoInline::AnalyzeInlineType(
const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As<ir::ScheduleBlockRealize>(); const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); const ir::ScheduleBlockRealize* sche_block_realize =
sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
// Inline if the block has only 1 write buffer // Inline if the block has only 1 write buffer
if (sche_block->write_buffers.size() != 1) { if (sche_block->write_buffers.size() != 1) {
return AutoInlineType::kCannotInline; return AutoInlineType::kCannotInline;
} }
std::unordered_set<ir::IrNodeTy> no_inline_node_types = {ir::IrNodeTy::IfThenElse}; std::unordered_set<ir::IrNodeTy> no_inline_node_types = {
ir::IrNodeTy::IfThenElse};
if (ContainsNodeType(sche_block->body, no_inline_node_types)) { if (ContainsNodeType(sche_block->body, no_inline_node_types)) {
return AutoInlineType::kCannotInline; return AutoInlineType::kCannotInline;
} }
...@@ -125,31 +142,38 @@ AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr ...@@ -125,31 +142,38 @@ AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr
} }
RuleApplyType AutoInline::Init(ir::IRSchedule* ir_schedule) { RuleApplyType AutoInline::Init(ir::IRSchedule* ir_schedule) {
ir_schedule_ = ir_schedule; ir_schedule_ = ir_schedule;
all_block_realizes_ = ir_schedule_->GetAllBlocks(); all_block_realizes_ = ir_schedule_->GetAllBlocks();
apply_indices_and_type_.clear(); apply_indices_and_type_.clear();
num_applicable_ = 0; num_applicable_ = 0;
for (size_t i = 0; i < all_block_realizes_.size(); ++i) { for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As<ir::ScheduleBlock>()); all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AutoInlineType type = AnalyzeInlineType(all_block_realizes_[i], ir_schedule_); AnalyzeScheduleBlockReadWriteBuffer(
sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type =
AnalyzeInlineType(all_block_realizes_[i], ir_schedule_);
if (type != AutoInlineType::kCannotInline) { if (type != AutoInlineType::kCannotInline) {
++num_applicable_; ++num_applicable_;
apply_indices_and_type_.push_back({i, type}); apply_indices_and_type_.push_back({i, type});
} }
} }
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
void AutoInline::Apply(int index) { void AutoInline::Apply(int index) {
CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init"; CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init";
CHECK(num_applicable_ > 0 && apply_indices_and_type_.size() == num_applicable_) CHECK(num_applicable_ > 0 &&
apply_indices_and_type_.size() == num_applicable_)
<< "AutoInline::Apply pre-condition doesn't meet"; << "AutoInline::Apply pre-condition doesn't meet";
CHECK(index >= 0 && num_applicable_ > index) CHECK(index >= 0 && num_applicable_ > index)
<< "Invalid index for AutoInline::Apply, the index needs 0 <= index && index < NumberApplicable(), " << "Invalid index for AutoInline::Apply, the index needs 0 <= index && "
<< "Currently index = " << index << ", NumberApplicable() = " << num_applicable_; "index < NumberApplicable(), "
<< "Currently index = " << index
<< ", NumberApplicable() = " << num_applicable_;
int apply_index = apply_indices_and_type_[index].first; int apply_index = apply_indices_and_type_[index].first;
Apply(ir_schedule_, all_block_realizes_[apply_index]); Apply(ir_schedule_, all_block_realizes_[apply_index]);
...@@ -158,20 +182,25 @@ void AutoInline::Apply(int index) { ...@@ -158,20 +182,25 @@ void AutoInline::Apply(int index) {
std::string AutoInline::GetRuleName() const { return "AutoInline"; } std::string AutoInline::GetRuleName() const { return "AutoInline"; }
RuleApplyType AutoInline::AnalyseApplyType(SearchState state, const std::string& block_name) const { RuleApplyType AutoInline::AnalyseApplyType(
Expr block_expr = state->ir_schedule.GetBlock(block_name); SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>(); auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>()); AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(block_expr, &state->ir_schedule); AutoInlineType type = AnalyzeInlineType(block_expr, &state->ir_schedule);
return type == AutoInlineType::kCannotInline ? RuleApplyType::kCannotApply : RuleApplyType::kApplyAndPruneOtherRules; return type == AutoInlineType::kCannotInline
? RuleApplyType::kCannotApply
: RuleApplyType::kApplyAndPruneOtherRules;
} }
std::vector<SearchState> AutoInline::ApplyOnBlock(SearchState state, const std::string& block_name) { std::vector<SearchState> AutoInline::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy(); SearchState new_state = state.Copy();
Expr block_expr = new_state->ir_schedule.GetBlock(block_name); Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
Apply(&new_state->ir_schedule, block_expr); Apply(&new_state->ir_schedule, block_expr);
return {new_state}; return {new_state};
...@@ -181,7 +210,8 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { ...@@ -181,7 +210,8 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>(); auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>()); AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(block_expr, ir_schedule); AutoInlineType type = AnalyzeInlineType(block_expr, ir_schedule);
if (type == AutoInlineType::kInlineIntoConsumer) { if (type == AutoInlineType::kInlineIntoConsumer) {
...@@ -202,10 +232,12 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { ...@@ -202,10 +232,12 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
// we need to re-analyze // we need to re-analyze
all_block_realizes_ = ir_schedule->GetAllBlocks(); all_block_realizes_ = ir_schedule->GetAllBlocks();
for (size_t i = 0; i < all_block_realizes_.size(); ++i) { for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
sche_block->read_buffers = {}; ir::ScheduleBlock* sche_block =
sche_block->write_buffers = {}; sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
sche_block->read_buffers = {};
sche_block->write_buffers = {};
AnalyzeScheduleBlockReadWriteBuffer(sche_block); AnalyzeScheduleBlockReadWriteBuffer(sche_block);
} }
} }
......
...@@ -41,7 +41,8 @@ enum class AutoInlineType : int { ...@@ -41,7 +41,8 @@ enum class AutoInlineType : int {
class AutoInline : public AutoGenRule { class AutoInline : public AutoGenRule {
public: public:
AutoInline(const common::Target& target, const std::unordered_set<std::string>& no_inline_output_names); AutoInline(const common::Target& target,
const std::unordered_set<std::string>& no_inline_output_names);
~AutoInline() = default; ~AutoInline() = default;
RuleApplyType Init(ir::IRSchedule* ir_schedule) override; RuleApplyType Init(ir::IRSchedule* ir_schedule) override;
...@@ -50,13 +51,17 @@ class AutoInline : public AutoGenRule { ...@@ -50,13 +51,17 @@ class AutoInline : public AutoGenRule {
std::string GetRuleName() const override; std::string GetRuleName() const override;
AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const; AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;
bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const; bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override; std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private: private:
void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr);
......
...@@ -63,7 +63,14 @@ TEST(AutoInline, SingleLoopInline) { ...@@ -63,7 +63,14 @@ TEST(AutoInline, SingleLoopInline) {
poly::StageMap stages = CreateStages({A, B, C}); poly::StageMap stages = CreateStages({A, B, C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestAutoInline_SingleLoopInline", stages, {A, C}, {}, {}, nullptr, target, true); lang::LowerVec("TestAutoInline_SingleLoopInline",
stages,
{A, C},
{},
{},
nullptr,
target,
true);
VLOG(6) << "Expr after lowering:"; VLOG(6) << "Expr after lowering:";
VLOG(6) << funcs[0]->body; VLOG(6) << funcs[0]->body;
...@@ -74,7 +81,7 @@ TEST(AutoInline, SingleLoopInline) { ...@@ -74,7 +81,7 @@ TEST(AutoInline, SingleLoopInline) {
*/ */
ir::IRSchedule ir_sch(ir::ModuleExpr(std::vector<ir::Expr>{funcs[0]->body})); ir::IRSchedule ir_sch(ir::ModuleExpr(std::vector<ir::Expr>{funcs[0]->body}));
SearchState state(ir_sch, 0, {}); SearchState state(ir_sch, 0, {});
ir::Expr block_b = ir_sch.GetBlock("B"); ir::Expr block_b = ir_sch.GetBlock("B");
std::vector<ir::Expr> loops = ir_sch.GetLoops("C"); std::vector<ir::Expr> loops = ir_sch.GetLoops("C");
ir_sch.ComputeAt(block_b, loops[0]); ir_sch.ComputeAt(block_b, loops[0]);
...@@ -90,12 +97,13 @@ TEST(AutoInline, SingleLoopInline) { ...@@ -90,12 +97,13 @@ TEST(AutoInline, SingleLoopInline) {
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "B"); auto new_states = auto_inline.ApplyOnBlock(state, "B");
auto test_func = [](ir::IRSchedule* ir_sch) { auto test_func = [](ir::IRSchedule* ir_sch) {
ir::ModuleExpr mod_expr_after_inline = ir_sch->GetModule(); ir::ModuleExpr mod_expr_after_inline = ir_sch->GetModule();
std::vector<ir::Expr> exprs = mod_expr_after_inline.GetExprs(); std::vector<ir::Expr> exprs = mod_expr_after_inline.GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss; std::stringstream ss;
...@@ -130,7 +138,8 @@ TEST(AutoInline, SingleLoopInline) { ...@@ -130,7 +138,8 @@ TEST(AutoInline, SingleLoopInline) {
// Cannot inline above expr again // Cannot inline above expr again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply); EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"), RuleApplyType::kCannotApply); EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"),
RuleApplyType::kCannotApply);
} }
TEST(AutoInline, AddReluInline) { TEST(AutoInline, AddReluInline) {
...@@ -148,15 +157,20 @@ TEST(AutoInline, AddReluInline) { ...@@ -148,15 +157,20 @@ TEST(AutoInline, AddReluInline) {
frontend::Program program = builder.Build(); frontend::Program program = builder.Build();
FLAGS_cinn_ir_schedule = true; FLAGS_cinn_ir_schedule = true;
auto graph = std::make_shared<Graph>(program, target); auto graph = std::make_shared<Graph>(program, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); const auto& dtype_dict =
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target); "inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
EXPECT_EQ(graph->fusion_groups.size(), 1UL); EXPECT_EQ(graph->fusion_groups.size(), 1UL);
std::vector<ir::LoweredFunc> funcs = op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]); std::vector<ir::LoweredFunc> funcs =
op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
VLOG(6) << "Expr before auto inline: " << funcs[0]->body; VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
...@@ -170,7 +184,7 @@ TEST(AutoInline, AddReluInline) { ...@@ -170,7 +184,7 @@ TEST(AutoInline, AddReluInline) {
auto_inline.Apply(1); auto_inline.Apply(1);
ir::ModuleExpr mod_expr_after_inline = ir_sch.GetModule(); ir::ModuleExpr mod_expr_after_inline = ir_sch.GetModule();
std::vector<ir::Expr> exprs = mod_expr_after_inline.GetExprs(); std::vector<ir::Expr> exprs = mod_expr_after_inline.GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss; std::stringstream ss;
...@@ -186,15 +200,17 @@ TEST(AutoInline, AddReluInline) { ...@@ -186,15 +200,17 @@ TEST(AutoInline, AddReluInline) {
auto_inline.Apply(0); auto_inline.Apply(0);
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); auto new_states = auto_inline.ApplyOnBlock(state, "var_1");
// Auto Inline again // Auto Inline again
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"),
RuleApplyType::kApplyAndPruneOtherRules);
new_states = auto_inline.ApplyOnBlock(new_states[0], "var_3"); new_states = auto_inline.ApplyOnBlock(new_states[0], "var_3");
auto test_func = [](ir::IRSchedule* ir_sch) { auto test_func = [](ir::IRSchedule* ir_sch) {
ir::ModuleExpr final_mod_expr = ir_sch->GetModule(); ir::ModuleExpr final_mod_expr = ir_sch->GetModule();
auto exprs = final_mod_expr.GetExprs(); auto exprs = final_mod_expr.GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss; std::stringstream ss;
...@@ -238,7 +254,8 @@ TEST(AutoInline, AddReluInline) { ...@@ -238,7 +254,8 @@ TEST(AutoInline, AddReluInline) {
// Cannot inline above expr again // Cannot inline above expr again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply); EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"), RuleApplyType::kCannotApply); EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"),
RuleApplyType::kCannotApply);
} }
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
...@@ -246,14 +263,8 @@ class TestAutoInline : public TestAutoGenRuleBase {}; ...@@ -246,14 +263,8 @@ class TestAutoInline : public TestAutoGenRuleBase {};
/* The single chain graph composed of multiple blocks can be inlined into one. /* The single chain graph composed of multiple blocks can be inlined into one.
* *
* Before AutoInline: The output of the previous block is the input of another block. * Before AutoInline: The output of the previous block is the input of another
* Loop1: * block. Loop1: x1 = Add() Loop2: x2 = Multiply(x1) Loop3: x3 = Add(x2) Loop4:
* x1 = Add()
* Loop2:
* x2 = Multiply(x1)
* Loop3:
* x3 = Add(x2)
* Loop4:
* x4 = Relu(x3) * x4 = Relu(x3)
* *
* After AutoInline: All loops are inlined into a loop. * After AutoInline: All loops are inlined into a loop.
...@@ -263,18 +274,22 @@ class TestAutoInline : public TestAutoGenRuleBase {}; ...@@ -263,18 +274,22 @@ class TestAutoInline : public TestAutoGenRuleBase {};
TEST_F(TestAutoInline, SingleChain) { TEST_F(TestAutoInline, SingleChain) {
Target target = common::DefaultNVGPUTarget(); Target target = common::DefaultNVGPUTarget();
Initialize(target); Initialize(target);
std::vector<std::string> input_names = {"bias", "conv_output", "bn_scale", "bn_offset"}; std::vector<std::string> input_names = {
std::vector<std::string> output_names = {"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"}; "bias", "conv_output", "bn_scale", "bn_offset"};
std::vector<std::string> output_names = {
"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"};
std::vector<int32_t> conv_output_shape = {1, 512, 56, 56}; std::vector<int32_t> conv_output_shape = {1, 512, 56, 56};
int32_t channel = conv_output_shape[1]; int32_t channel = conv_output_shape[1];
std::vector<tests::VariableInfo> inputs_varinfo({{"conv_output", conv_output_shape}, std::vector<tests::VariableInfo> inputs_varinfo(
{"bias", {channel, 1, 1}}, {{"conv_output", conv_output_shape},
{"bn_scale", {channel, 1, 1}}, {"bias", {channel, 1, 1}},
{"bn_offset", {channel, 1, 1}}}); {"bn_scale", {channel, 1, 1}},
{"bn_offset", {channel, 1, 1}}});
// Construct the computation graph and convert it to ir::Expr // Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId(); Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo)); ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
...@@ -282,20 +297,23 @@ TEST_F(TestAutoInline, SingleChain) { ...@@ -282,20 +297,23 @@ TEST_F(TestAutoInline, SingleChain) {
// Apply AutoInline for every block that can be inline // Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()}); AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_3"); auto new_states = auto_inline.ApplyOnBlock(state, "var_3");
std::vector<std::string> inline_block_names({"var_4", "var_5", "var_6", "var", "var_0", "var_1"}); std::vector<std::string> inline_block_names(
{"var_4", "var_5", "var_6", "var", "var_0", "var_1"});
for (const auto& inline_block_name : inline_block_names) { for (const auto& inline_block_name : inline_block_names) {
new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name); new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name);
} }
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = auto build_module_manually = BuildIRModule(MakeIRSchedule(
BuildIRModule(MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true)); tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto); auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto; VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually); auto source_code_manually = GenSourceCode(build_module_manually);
...@@ -305,7 +323,10 @@ TEST_F(TestAutoInline, SingleChain) { ...@@ -305,7 +323,10 @@ TEST_F(TestAutoInline, SingleChain) {
GenExecutableKernel(build_module_manually), GenExecutableKernel(build_module_manually),
input_names, input_names,
output_names, output_names,
{{conv_output_shape[1], 1, 1}, conv_output_shape, conv_output_shape, conv_output_shape}, {{conv_output_shape[1], 1, 1},
conv_output_shape,
conv_output_shape,
conv_output_shape},
{conv_output_shape, {1}, {1}, {1}, {1}, {1}, {1}}, {conv_output_shape, {1}, {1}, {1}, {1}, {1}, {1}},
target); target);
} }
...@@ -328,14 +349,15 @@ TEST_F(TestAutoInline, SingleChain) { ...@@ -328,14 +349,15 @@ TEST_F(TestAutoInline, SingleChain) {
TEST_F(TestAutoInline, InlineToMultiConsumers) { TEST_F(TestAutoInline, InlineToMultiConsumers) {
Target target = common::DefaultNVGPUTarget(); Target target = common::DefaultNVGPUTarget();
Initialize(target); Initialize(target);
std::vector<std::string> input_names = {"x"}; std::vector<std::string> input_names = {"x"};
std::vector<std::string> output_names = {"var_2", "var_1", "var_0"}; std::vector<std::string> output_names = {"var_2", "var_1", "var_0"};
std::vector<int32_t> input_shape{256, 256}; std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}}); std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}});
// Construct the computation graph and convert it to ir::Expr // Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId(); Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo)); ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
...@@ -343,17 +365,19 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) { ...@@ -343,17 +365,19 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) {
// Apply AutoInline for every block that can be inline // Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()}); AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"),
auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); RuleApplyType::kApplyAndPruneOtherRules);
new_states = auto_inline.ApplyOnBlock(state, "var_0"); auto new_states = auto_inline.ApplyOnBlock(state, "var_1");
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); new_states = auto_inline.ApplyOnBlock(state, "var_0");
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = auto build_module_manually = BuildIRModule(MakeIRSchedule(
BuildIRModule(MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true)); tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto); auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto; VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually); auto source_code_manually = GenSourceCode(build_module_manually);
...@@ -386,15 +410,21 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) { ...@@ -386,15 +410,21 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) {
TEST_F(TestAutoInline, OnlySpatialOp) { TEST_F(TestAutoInline, OnlySpatialOp) {
Target target = common::DefaultNVGPUTarget(); Target target = common::DefaultNVGPUTarget();
Initialize(target); Initialize(target);
std::vector<std::string> input_names = {"x", "y"}; std::vector<std::string> input_names = {"x", "y"};
std::vector<std::string> output_names = { std::vector<std::string> output_names = {"var_6",
"var_6", "var_4", "constant_idx_last", "constant_idx_first", "var_2", "var_5"}; "var_4",
"constant_idx_last",
"constant_idx_first",
"var_2",
"var_5"};
std::vector<int32_t> input_shape{256, 256}; std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}, {"y", input_shape}}); std::vector<tests::VariableInfo> inputs_varinfo(
{{"x", input_shape}, {"y", input_shape}});
// Construct the computation graph and convert it to ir::Expr // Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId(); Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo)); ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
...@@ -402,20 +432,23 @@ TEST_F(TestAutoInline, OnlySpatialOp) { ...@@ -402,20 +432,23 @@ TEST_F(TestAutoInline, OnlySpatialOp) {
// Apply AutoInline for every block that can be inline // Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()}); AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "constant_idx_first"); auto new_states = auto_inline.ApplyOnBlock(state, "constant_idx_first");
std::vector<std::string> inline_block_names({"constant_idx_last", "var_2", "var_5", "var_4"}); std::vector<std::string> inline_block_names(
{"constant_idx_last", "var_2", "var_5", "var_4"});
for (const auto& inline_block_name : inline_block_names) { for (const auto& inline_block_name : inline_block_names) {
new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name); new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name);
} }
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = auto build_module_manually = BuildIRModule(MakeIRSchedule(
BuildIRModule(MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true)); tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto); auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto; VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually); auto source_code_manually = GenSourceCode(build_module_manually);
...@@ -445,13 +478,14 @@ TEST_F(TestAutoInline, OnlySpatialOp) { ...@@ -445,13 +478,14 @@ TEST_F(TestAutoInline, OnlySpatialOp) {
TEST_F(TestAutoInline, NoReadBufferOp) { TEST_F(TestAutoInline, NoReadBufferOp) {
Target target = common::DefaultNVGPUTarget(); Target target = common::DefaultNVGPUTarget();
Initialize(target); Initialize(target);
std::vector<std::string> input_names = {"x"}; std::vector<std::string> input_names = {"x"};
std::vector<std::string> output_names = {"var_0", "fill_constant"}; std::vector<std::string> output_names = {"var_0", "fill_constant"};
std::vector<int32_t> input_shape{256, 256}; std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}}); std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}});
// Construct the computation graph and convert it to ir::Expr // Construct the computation graph and convert it to ir::Expr
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo)); ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
...@@ -459,16 +493,18 @@ TEST_F(TestAutoInline, NoReadBufferOp) { ...@@ -459,16 +493,18 @@ TEST_F(TestAutoInline, NoReadBufferOp) {
// Apply AutoInline for every block that can be inline // Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()}); AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"),
auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant"); RuleApplyType::kApplyAndPruneOtherRules);
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant");
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = auto build_module_manually = BuildIRModule(MakeIRSchedule(
BuildIRModule(MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true)); tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto); auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto; VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually); auto source_code_manually = GenSourceCode(build_module_manually);
......
...@@ -33,11 +33,13 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { ...@@ -33,11 +33,13 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
auto has_reduce_iter = [](const Expr* x) { auto has_reduce_iter = [](const Expr* x) {
auto* block_realize = x->As<ir::ScheduleBlockRealize>(); auto* block_realize = x->As<ir::ScheduleBlockRealize>();
if (block_realize) { if (block_realize) {
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>(); auto* schedule_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock"; CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
for (auto&& var : schedule_block->iter_vars) { for (auto&& var : schedule_block->iter_vars) {
if (var->is_reduce_axis) { if (var->is_reduce_axis) {
VLOG(6) << "find ScheduleBlockRealize:" << *x << " has reduce_axis:" << var; VLOG(6) << "find ScheduleBlockRealize:" << *x
<< " has reduce_axis:" << var;
return true; return true;
} }
} }
...@@ -46,7 +48,8 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { ...@@ -46,7 +48,8 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
}; };
// whether has any for-loop with non-serial type // whether has any for-loop with non-serial type
auto has_nonserial_loop = [](const Expr* x) { auto has_nonserial_loop = [](const Expr* x) {
if (x->As<ir::For>() && x->As<ir::For>()->for_type() != ir::ForType::Serial) { if (x->As<ir::For>() &&
x->As<ir::For>()->for_type() != ir::ForType::Serial) {
VLOG(6) << "find non-serial loop:" << *x; VLOG(6) << "find non-serial loop:" << *x;
return true; return true;
} }
...@@ -55,13 +58,15 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { ...@@ -55,13 +58,15 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
auto find_target_exprs = ir::CollectIRNodesWithoutTensor( auto find_target_exprs = ir::CollectIRNodesWithoutTensor(
schedule_block->body, schedule_block->body,
[&has_reduce_iter, &has_nonserial_loop](const Expr* x) { return has_reduce_iter(x) || has_nonserial_loop(x); }); [&has_reduce_iter, &has_nonserial_loop](const Expr* x) {
return has_reduce_iter(x) || has_nonserial_loop(x);
});
return !find_target_exprs.empty(); return !find_target_exprs.empty();
} }
RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) { RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) {
ir_schedule_ = ir_schedule; ir_schedule_ = ir_schedule;
auto block_realizes = ir_schedule_->GetAllBlocks(); auto block_realizes = ir_schedule_->GetAllBlocks();
// A schedule block can perform `auto_unroll` rule should meet two conditions: // A schedule block can perform `auto_unroll` rule should meet two conditions:
...@@ -71,47 +76,58 @@ RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) { ...@@ -71,47 +76,58 @@ RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) {
std::set<Expr> deduplicate_results; std::set<Expr> deduplicate_results;
for (size_t i = 0; i < block_realizes.size(); ++i) { for (size_t i = 0; i < block_realizes.size(); ++i) {
// find root block // find root block
Expr root_block = ir_schedule_->GetRootBlock(block_realizes[i]); Expr root_block = ir_schedule_->GetRootBlock(block_realizes[i]);
auto* block_realize = root_block.As<ir::ScheduleBlockRealize>(); auto* block_realize = root_block.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block;
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>(); auto* schedule_block =
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize); block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:"
<< Expr(block_realize);
if (MeetCondition(schedule_block)) { if (MeetCondition(schedule_block)) {
deduplicate_results.emplace(root_block); deduplicate_results.emplace(root_block);
} }
} }
applicable_schedule_blocks_ = {deduplicate_results.begin(), deduplicate_results.end()}; applicable_schedule_blocks_ = {deduplicate_results.begin(),
num_applicable_ = applicable_schedule_blocks_.size(); deduplicate_results.end()};
num_applicable_ = applicable_schedule_blocks_.size();
VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_; VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
void AutoUnroll::Apply(int index) { void AutoUnroll::Apply(int index) {
CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index; CHECK_LT(index, applicable_schedule_blocks_.size())
<< "invalid apply index:" << index;
auto applied_block = applicable_schedule_blocks_.at(index); auto applied_block = applicable_schedule_blocks_.at(index);
int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()];
ir_schedule_->Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step); ir_schedule_->Annotate(
applied_block, ir::attr::auto_unroll_max_step, max_step);
return; return;
} }
RuleApplyType AutoUnroll::AnalyseApplyType(SearchState state, const std::string& block_name) const { RuleApplyType AutoUnroll::AnalyseApplyType(
Expr block_expr = state->ir_schedule.GetBlock(block_name); SearchState state, const std::string& block_name) const {
Expr root_block = state->ir_schedule.GetRootBlock(block_expr); Expr block_expr = state->ir_schedule.GetBlock(block_name);
Expr root_block = state->ir_schedule.GetRootBlock(block_expr);
auto* block_realize = root_block.As<ir::ScheduleBlockRealize>(); auto* block_realize = root_block.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block;
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>(); auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize); CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:"
<< Expr(block_realize);
return MeetCondition(schedule_block) ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return MeetCondition(schedule_block) ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
std::vector<SearchState> AutoUnroll::ApplyOnBlock(SearchState state, const std::string& block_name) { std::vector<SearchState> AutoUnroll::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy(); SearchState new_state = state.Copy();
Expr block_expr = new_state->ir_schedule.GetBlock(block_name); Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr); Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr);
int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()];
new_state->ir_schedule.Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step); new_state->ir_schedule.Annotate(
applied_block, ir::attr::auto_unroll_max_step, max_step);
return {new_state}; return {new_state};
} }
......
...@@ -24,10 +24,11 @@ ...@@ -24,10 +24,11 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
// This rule can be applied in a ScheduleBlock has reduce axis or has loops with non-serial type. // This rule can be applied in a ScheduleBlock has reduce axis or has loops with
// As a result, it will set a attribute with key named ir::attr::auto_unroll_max_step and value // non-serial type. As a result, it will set a attribute with key named
// indicating max permitted unrolled step in the applied ScheduleBlock. Finally, UnrollLoop pass // ir::attr::auto_unroll_max_step and value indicating max permitted unrolled
// will do unroll based on actual situation. // step in the applied ScheduleBlock. Finally, UnrollLoop pass will do unroll
// based on actual situation.
class AutoUnroll : public AutoGenRule { class AutoUnroll : public AutoGenRule {
public: public:
AutoUnroll(const common::Target& target) : AutoGenRule(target) {} AutoUnroll(const common::Target& target) : AutoGenRule(target) {}
...@@ -39,9 +40,11 @@ class AutoUnroll : public AutoGenRule { ...@@ -39,9 +40,11 @@ class AutoUnroll : public AutoGenRule {
std::string GetRuleName() const override { return "AutoUnroll"; } std::string GetRuleName() const override { return "AutoUnroll"; }
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override; std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private: private:
bool MeetCondition(const ir::ScheduleBlock* schedule_block) const; bool MeetCondition(const ir::ScheduleBlock* schedule_block) const;
......
...@@ -39,7 +39,8 @@ TEST(AutoUnroll, Init) { ...@@ -39,7 +39,8 @@ TEST(AutoUnroll, Init) {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
auto stages = CreateStages({C}); auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec("test_init", stages, {A, B, C}, {}, {}, nullptr, target, true); auto funcs = cinn::lang::LowerVec(
"test_init", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto ast_expr = funcs[0]->body; auto ast_expr = funcs[0]->body;
ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr})); ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr}));
...@@ -58,7 +59,9 @@ TEST(AutoUnroll, UnrollableApply) { ...@@ -58,7 +59,9 @@ TEST(AutoUnroll, UnrollableApply) {
Placeholder<float> B("B", {K, N}); Placeholder<float> B("B", {K, N});
Var k(K.as_int32(), "k0"); Var k(K.as_int32(), "k0");
Tensor C = Compute( Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); {M, N},
[&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget(); Target target = common::DefaultNVGPUTarget();
...@@ -66,11 +69,14 @@ TEST(AutoUnroll, UnrollableApply) { ...@@ -66,11 +69,14 @@ TEST(AutoUnroll, UnrollableApply) {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
auto stages = CreateStages({C}); auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec("test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true); auto funcs = cinn::lang::LowerVec(
"test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto ast_expr = funcs[0]->body; auto ast_expr = funcs[0]->body;
auto* init_block_realize = ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>(); auto* init_block_realize =
auto* init_schedule_block = init_block_realize->schedule_block.As<ir::ScheduleBlock>(); ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
auto* init_schedule_block =
init_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_NE(init_schedule_block, nullptr); ASSERT_NE(init_schedule_block, nullptr);
ASSERT_TRUE(init_schedule_block->attrs.empty()); ASSERT_TRUE(init_schedule_block->attrs.empty());
VLOG(6) << "Before auto-unroll:\n" << ast_expr; VLOG(6) << "Before auto-unroll:\n" << ast_expr;
...@@ -78,25 +84,34 @@ TEST(AutoUnroll, UnrollableApply) { ...@@ -78,25 +84,34 @@ TEST(AutoUnroll, UnrollableApply) {
AutoUnroll test_rule(target); AutoUnroll test_rule(target);
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
ASSERT_EQ(test_rule.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); ASSERT_EQ(test_rule.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(test_rule.NumberApplicable(), 1); EXPECT_EQ(test_rule.NumberApplicable(), 1);
test_rule.ApplyRandomly(); test_rule.ApplyRandomly();
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"),
std::vector<cinn::auto_schedule::SearchState> states = test_rule.ApplyOnBlock(state, "C"); RuleApplyType::kApplyAndPruneOtherRules);
std::vector<cinn::auto_schedule::SearchState> states =
test_rule.ApplyOnBlock(state, "C");
auto test_func = [](IRSchedule* ir_sch) { auto test_func = [](IRSchedule* ir_sch) {
Expr applied_expr = ir_sch->GetModule().GetExprs().front(); Expr applied_expr = ir_sch->GetModule().GetExprs().front();
auto* applied_block_realize = applied_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>(); auto* applied_block_realize = applied_expr.As<ir::Block>()
auto* applied_schedule_block = applied_block_realize->schedule_block.As<ir::ScheduleBlock>(); ->stmts.front()
.As<ir::ScheduleBlockRealize>();
auto* applied_schedule_block =
applied_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_FALSE(applied_schedule_block->attrs.empty()); ASSERT_FALSE(applied_schedule_block->attrs.empty());
EXPECT_EQ(applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1); EXPECT_EQ(
const auto& attr_value = applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step); applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1);
const int* max_step = absl::get_if<int>(&attr_value); const auto& attr_value =
applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step);
const int* max_step = absl::get_if<int>(&attr_value);
EXPECT_NE(max_step, nullptr); EXPECT_NE(max_step, nullptr);
EXPECT_LE(*max_step, 128); EXPECT_LE(*max_step, 128);
VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front(); VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n"
<< ir_sch->GetModule().GetExprs().front();
}; };
test_func(&ir_schedule); test_func(&ir_schedule);
......
...@@ -29,32 +29,35 @@ namespace auto_schedule { ...@@ -29,32 +29,35 @@ namespace auto_schedule {
class TestMixRules : public TestAutoGenRuleBase { class TestMixRules : public TestAutoGenRuleBase {
public: public:
std::vector<std::string> default_input_names = {"X", "Y"}; std::vector<std::string> default_input_names = {"X", "Y"};
std::vector<std::string> default_output_names = {"temp_matmul_out"}; std::vector<std::string> default_output_names = {"temp_matmul_out"};
}; };
TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) { TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) {
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}); frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}});
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op); ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op);
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0]; VLOG(6) << "Original Expr:\n" << func_bodys[0];
// Apply MultiLevelTiling // Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
multi_level_tiling.Init(&ir_schedule); multi_level_tiling.Init(&ir_schedule);
ASSERT_EQ(multi_level_tiling.NumberApplicable(), 1); ASSERT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly(); multi_level_tiling.ApplyRandomly();
VLOG(6) << "after MultiLevelTiling Expr:\n" << func_bodys[0]; VLOG(6) << "after MultiLevelTiling Expr:\n" << func_bodys[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto ir_module = BuildIRModule(ir_schedule); auto ir_module = BuildIRModule(ir_schedule);
auto source_code = GenSourceCode(ir_module); auto source_code = GenSourceCode(ir_module);
VLOG(6) << "scheduled source code:\n" << source_code; VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision // execute and check precision
CheckResult(GenExecutableKernel(ir_module), CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))), GenExecutableKernel(BuildIRModule(
MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))),
default_input_names, default_input_names,
default_output_names, default_output_names,
{{32, 32}, {32, 32}}, {{32, 32}, {32, 32}},
......
...@@ -72,9 +72,11 @@ class MultiLevelTiling : public AutoGenRule { ...@@ -72,9 +72,11 @@ class MultiLevelTiling : public AutoGenRule {
// Returns true if sche_block_realize is applicable by MultiLevelTiling // Returns true if sche_block_realize is applicable by MultiLevelTiling
bool MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const; bool MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override; std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
// Sample pair of integer type (a, b) such as a * b = extent // Sample pair of integer type (a, b) such as a * b = extent
template <typename T> template <typename T>
...@@ -88,10 +90,10 @@ class MultiLevelTiling : public AutoGenRule { ...@@ -88,10 +90,10 @@ class MultiLevelTiling : public AutoGenRule {
if (candidates.size() == 0) { if (candidates.size() == 0) {
return {1, T(extent)}; return {1, T(extent)};
} }
int index = rand() % candidates.size(); int index = rand() % candidates.size();
std::vector<T> pick = candidates[index]; std::vector<T> pick = candidates[index];
if (rand() % 2 != 0) { if (rand() % 2 != 0) {
T tmp = pick[0]; T tmp = pick[0];
pick[0] = pick[1]; pick[0] = pick[1];
pick[1] = tmp; pick[1] = tmp;
} }
...@@ -101,7 +103,8 @@ class MultiLevelTiling : public AutoGenRule { ...@@ -101,7 +103,8 @@ class MultiLevelTiling : public AutoGenRule {
// Sample num_split integers whose product equals extent // Sample num_split integers whose product equals extent
template <typename T> template <typename T>
std::vector<T> SampleTileSplit(T extent, int num_split) const { std::vector<T> SampleTileSplit(T extent, int num_split) const {
CHECK_GT(num_split, 0) << "num_split in SampleTileSplit must be greater than 0"; CHECK_GT(num_split, 0)
<< "num_split in SampleTileSplit must be greater than 0";
if (num_split == 1) { if (num_split == 1) {
return {extent}; return {extent};
} }
...@@ -109,7 +112,7 @@ class MultiLevelTiling : public AutoGenRule { ...@@ -109,7 +112,7 @@ class MultiLevelTiling : public AutoGenRule {
if (num_split == 2) { if (num_split == 2) {
return two_split; return two_split;
} }
int half = num_split >> 1; int half = num_split >> 1;
std::vector<T> result = SampleTileSplit<T>(two_split[0], half); std::vector<T> result = SampleTileSplit<T>(two_split[0], half);
std::vector<T> remind = SampleTileSplit<T>(two_split[1], num_split - half); std::vector<T> remind = SampleTileSplit<T>(two_split[1], num_split - half);
result.insert(result.end(), remind.begin(), remind.end()); result.insert(result.end(), remind.begin(), remind.end());
......
...@@ -48,11 +48,13 @@ TEST(MultiLevelTile, SampleSplitTwo) { ...@@ -48,11 +48,13 @@ TEST(MultiLevelTile, SampleSplitTwo) {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
size_t number_to_split = rand() % 65535 + 2; // random number in [2, 2^16] size_t number_to_split = rand() % 65535 + 2; // random number in [2, 2^16]
std::vector<size_t> split = multi_level_tiling.SampleSplitTwo<size_t>(number_to_split); std::vector<size_t> split =
multi_level_tiling.SampleSplitTwo<size_t>(number_to_split);
EXPECT_EQ(split.size(), 2UL); EXPECT_EQ(split.size(), 2UL);
EXPECT_EQ(split[0] * split[1], number_to_split); EXPECT_EQ(split[0] * split[1], number_to_split);
} }
...@@ -67,12 +69,14 @@ TEST(MultiLevelTile, SampleTileSplit) { ...@@ -67,12 +69,14 @@ TEST(MultiLevelTile, SampleTileSplit) {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
int number_to_split = rand() % 65535 + 2; // random number in [2, 2^16] int number_to_split = rand() % 65535 + 2; // random number in [2, 2^16]
int split_size = rand() % 5 + 1; // random in [1, 5] int split_size = rand() % 5 + 1; // random in [1, 5]
std::vector<int> split = multi_level_tiling.SampleTileSplit<int>(number_to_split, split_size); std::vector<int> split =
multi_level_tiling.SampleTileSplit<int>(number_to_split, split_size);
EXPECT_EQ(split.size(), static_cast<size_t>(split_size)); EXPECT_EQ(split.size(), static_cast<size_t>(split_size));
int product = 1; int product = 1;
for (int num : split) { for (int num : split) {
...@@ -102,21 +106,31 @@ TEST(MultiLevelTile, SimpleLoops) { ...@@ -102,21 +106,31 @@ TEST(MultiLevelTile, SimpleLoops) {
poly::StageMap stages = CreateStages({C}); poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_SimpleLoops", stages, {C}, {}, {}, nullptr, target, true); lang::LowerVec("TestMultiLevelTile_SimpleLoops",
stages,
{C},
{},
{},
nullptr,
target,
true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: "; VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr; VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1); EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly(); multi_level_tiling.ApplyRandomly();
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, "C"); auto new_states = multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) { auto test_func = [](ir::IRSchedule* ir_sch) {
...@@ -152,26 +166,30 @@ TEST(MulitLevelTile, MatrixMultiply) { ...@@ -152,26 +166,30 @@ TEST(MulitLevelTile, MatrixMultiply) {
Var k(K.as_int32(), "reduce_axis_k"); Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({C}); poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {},
nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: "; VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr; VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); MultiLevelTiling multi_level_tiling(target,
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); MultiLevelTiling::kConfigs.at(target.arch)); ir::IRSchedule
SearchState state(ir_schedule, 0, {}); ir_schedule(ir::ModuleExpr({ast_expr})); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1); EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly(); multi_level_tiling.ApplyRandomly();
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
auto new_states = multi_level_tiling.ApplyOnBlock(state, "C"); RuleApplyType::kApplyAndPruneOtherRules); auto new_states =
multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) { auto test_func = [](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs(); std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
...@@ -194,25 +212,28 @@ class TestMultiLevelTiling : public TestAutoGenRuleBase { ...@@ -194,25 +212,28 @@ class TestMultiLevelTiling : public TestAutoGenRuleBase {
}; };
TEST_F(TestMultiLevelTiling, Matmul) { TEST_F(TestMultiLevelTiling, Matmul) {
default_input_names = {"X", "Y"}; default_input_names = {"X", "Y"};
default_output_names = {"temp_matmul_out"}; default_output_names = {"temp_matmul_out"};
std::vector<int32_t> X_shape = {32, 32}; std::vector<int32_t> X_shape = {32, 32};
std::vector<int32_t> Y_shape = {32, 32}; std::vector<int32_t> Y_shape = {32, 32};
std::vector<int32_t> out_shape = {32, 32}; std::vector<int32_t> out_shape = {32, 32};
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); frontend::Program matmul_op =
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);
SearchState state(ir_schedule); SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString(); VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling // Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]),
RuleApplyType::kApplyAndPruneOtherRules); RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); auto new_states =
multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString(); VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString();
std::string ir = GetIR(new_states[0]->ir_schedule); std::string ir = GetIR(new_states[0]->ir_schedule);
std::string expected_ir = R"ROC(Expr 0 { std::string expected_ir = R"ROC(Expr 0 {
{ {
ScheduleBlock(root) ScheduleBlock(root)
...@@ -325,14 +346,15 @@ TEST_F(TestMultiLevelTiling, Matmul) { ...@@ -325,14 +346,15 @@ TEST_F(TestMultiLevelTiling, Matmul) {
ASSERT_EQ(ir, expected_ir); ASSERT_EQ(ir, expected_ir);
// build ir::Module and debug source code // build ir::Module and debug source code
auto ir_module = BuildIRModule(new_states[0]->ir_schedule); auto ir_module = BuildIRModule(new_states[0]->ir_schedule);
auto source_code = GenSourceCode(ir_module); auto source_code = GenSourceCode(ir_module);
VLOG(6) << "scheduled source code:\n" << source_code; VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision // execute and check precision
CheckResult( CheckResult(
GenExecutableKernel(ir_module), GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), GenExecutableKernel(BuildIRModule(MakeIRSchedule(
matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names, default_input_names,
default_output_names, default_output_names,
{X_shape, Y_shape}, {X_shape, Y_shape},
...@@ -341,26 +363,29 @@ TEST_F(TestMultiLevelTiling, Matmul) { ...@@ -341,26 +363,29 @@ TEST_F(TestMultiLevelTiling, Matmul) {
} }
TEST_F(TestMultiLevelTiling, ReduceSum) { TEST_F(TestMultiLevelTiling, ReduceSum) {
default_input_names = {"X"}; default_input_names = {"X"};
default_output_names = {"var_0_tmp"}; default_output_names = {"var_0_tmp"};
std::vector<int32_t> X_shape = {1, 16, 32}; std::vector<int32_t> X_shape = {1, 16, 32};
std::vector<int32_t> out_shape = {1, 16, 1}; std::vector<int32_t> out_shape = {1, 16, 1};
std::vector<int32_t> reduce_dim = {2}; std::vector<int32_t> reduce_dim = {2};
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
frontend::Program reduce_sum_op = frontend::Program reduce_sum_op =
tests::OpBuilder("reduce_sum").Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}}); tests::OpBuilder("reduce_sum")
.Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}});
ir::IRSchedule ir_schedule = MakeIRSchedule(reduce_sum_op); ir::IRSchedule ir_schedule = MakeIRSchedule(reduce_sum_op);
SearchState state(ir_schedule); SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString(); VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling // Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); MultiLevelTiling multi_level_tiling(
// EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), RuleApplyType::kCannotApply); target_, MultiLevelTiling::kConfigs.at(target_.arch));
// EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state,
// default_output_names[0]), RuleApplyType::kCannotApply);
} }
TEST_F(TestMultiLevelTiling, Pool2d) { TEST_F(TestMultiLevelTiling, Pool2d) {
default_input_names = {"input"}; default_input_names = {"input"};
default_output_names = {"var_0"}; default_output_names = {"var_0"};
std::vector<int32_t> input_shape{2, 8, 16, 16}; std::vector<int32_t> input_shape{2, 8, 16, 16};
std::vector<int32_t> output_shape{2, 8, 8, 8}; std::vector<int32_t> output_shape{2, 8, 8, 8};
...@@ -368,23 +393,24 @@ TEST_F(TestMultiLevelTiling, Pool2d) { ...@@ -368,23 +393,24 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
std::vector<int> ksize{3, 3}; std::vector<int> ksize{3, 3};
std::vector<int> strides{2, 2}; std::vector<int> strides{2, 2};
std::vector<int> paddings{1, 1, 1, 1}; std::vector<int> paddings{1, 1, 1, 1};
bool ceil_mode = false; bool ceil_mode = false;
bool exclusive = true; bool exclusive = true;
bool global_pooling = false; bool global_pooling = false;
std::string data_format = "NCHW"; std::string data_format = "NCHW";
bool adaptive = false; bool adaptive = false;
std::string padding_algorithm = "EXPLICIT"; std::string padding_algorithm = "EXPLICIT";
frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shape}}, frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
{{"pool_type", pooling_type}, {{"input", input_shape}},
{"kernel_size", ksize}, {{"pool_type", pooling_type},
{"stride_size", strides}, {"kernel_size", ksize},
{"padding_size", paddings}, {"stride_size", strides},
{"ceil_mode", ceil_mode}, {"padding_size", paddings},
{"exclusive", exclusive}, {"ceil_mode", ceil_mode},
{"global_pooling", global_pooling}, {"exclusive", exclusive},
{"data_format", data_format}, {"global_pooling", global_pooling},
{"adaptive", adaptive}, {"data_format", data_format},
{"padding_algorithm", padding_algorithm}}); {"adaptive", adaptive},
{"padding_algorithm", padding_algorithm}});
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(pool2d_program, fixed_rand_seed); ir::IRSchedule ir_schedule = MakeIRSchedule(pool2d_program, fixed_rand_seed);
...@@ -403,10 +429,11 @@ TEST_F(TestMultiLevelTiling, Pool2d) { ...@@ -403,10 +429,11 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
MultiLevelTiling multi_level_tiling(target_, mlt_config); MultiLevelTiling multi_level_tiling(target_, mlt_config);
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]),
RuleApplyType::kApplyAndPruneOtherRules); RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); auto new_states =
multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString(); VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString();
std::string ir = GetIR(new_states[0]->ir_schedule); std::string ir = GetIR(new_states[0]->ir_schedule);
std::string expected_ir = R"ROC(Expr 0 { std::string expected_ir = R"ROC(Expr 0 {
{ {
ScheduleBlock(root) ScheduleBlock(root)
...@@ -529,19 +556,20 @@ Expr 1 { ...@@ -529,19 +556,20 @@ Expr 1 {
ASSERT_EQ(ir, expected_ir); ASSERT_EQ(ir, expected_ir);
// build ir::Module and debug source code // build ir::Module and debug source code
auto ir_module = BuildIRModule(new_states[0]->ir_schedule); auto ir_module = BuildIRModule(new_states[0]->ir_schedule);
auto source_code = GenSourceCode(ir_module); auto source_code = GenSourceCode(ir_module);
VLOG(6) << "scheduled source code:\n" << source_code; VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision // execute and check precision
CheckResult(GenExecutableKernel(ir_module), CheckResult(
GenExecutableKernel( GenExecutableKernel(ir_module),
BuildIRModule(MakeIRSchedule(pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))), GenExecutableKernel(BuildIRModule(MakeIRSchedule(
default_input_names, pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_output_names, default_input_names,
{input_shape}, default_output_names,
{output_shape}, {input_shape},
target_); {output_shape},
target_);
} }
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -27,7 +27,7 @@ namespace auto_schedule { ...@@ -27,7 +27,7 @@ namespace auto_schedule {
SkipRule::SkipRule(const common::Target& target) : AutoGenRule(target) {} SkipRule::SkipRule(const common::Target& target) : AutoGenRule(target) {}
RuleApplyType SkipRule::Init(ir::IRSchedule* ir_schedule) { RuleApplyType SkipRule::Init(ir::IRSchedule* ir_schedule) {
ir_schedule_ = ir_schedule; ir_schedule_ = ir_schedule;
num_applicable_ = 1; num_applicable_ = 1;
return RuleApplyType::kApply; return RuleApplyType::kApply;
} }
......
...@@ -34,11 +34,15 @@ class SkipRule : public AutoGenRule { ...@@ -34,11 +34,15 @@ class SkipRule : public AutoGenRule {
std::string GetRuleName() const override; std::string GetRuleName() const override;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override { RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override {
return RuleApplyType::kApply; return RuleApplyType::kApply;
} }
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override { return {state}; } std::vector<SearchState> ApplyOnBlock(
SearchState state, const std::string& block_name) override {
return {state};
}
}; };
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -52,8 +52,9 @@ TEST(SkipRule, Basic) { ...@@ -52,8 +52,9 @@ TEST(SkipRule, Basic) {
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C}); poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: "; VLOG(6) << "Expr before SkipRule: ";
...@@ -69,7 +70,8 @@ TEST(SkipRule, Basic) { ...@@ -69,7 +70,8 @@ TEST(SkipRule, Basic) {
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply); EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply);
std::vector<cinn::auto_schedule::SearchState> states = skip_rule.ApplyOnBlock(state, "C"); std::vector<cinn::auto_schedule::SearchState> states =
skip_rule.ApplyOnBlock(state, "C");
auto test_func = [&ast_expr](ir::IRSchedule* ir_sch) { auto test_func = [&ast_expr](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs(); std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
...@@ -99,8 +101,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) { ...@@ -99,8 +101,9 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C}); poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: "; VLOG(6) << "Expr before SkipRule: ";
...@@ -111,7 +114,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) { ...@@ -111,7 +114,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply); EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply);
std::vector<cinn::auto_schedule::SearchState> states = skip_rule.ApplyOnBlock(state, "C"); std::vector<cinn::auto_schedule::SearchState> states =
skip_rule.ApplyOnBlock(state, "C");
std::vector<ir::Expr> exprs = states[0]->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs = states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
......
...@@ -42,26 +42,32 @@ using ::cinn::hlir::framework::Shape; ...@@ -42,26 +42,32 @@ using ::cinn::hlir::framework::Shape;
using ::cinn::hlir::framework::Tensor; using ::cinn::hlir::framework::Tensor;
void TestAutoGenRuleBase::Initialize(const common::Target& target) { void TestAutoGenRuleBase::Initialize(const common::Target& target) {
target_ = target; target_ = target;
backend_compier_ = backends::Compiler::Create(target); backend_compier_ = backends::Compiler::Create(target);
} }
ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(const frontend::Program& test_program, ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
utils::LinearRandomEngine::StateType rand_seed, const frontend::Program& test_program,
bool apply_manual_schedule) { utils::LinearRandomEngine::StateType rand_seed,
bool apply_manual_schedule) {
Context::Global().ResetNameId(); Context::Global().ResetNameId();
auto graph = std::make_shared<hlir::framework::Graph>(test_program, target_); auto graph = std::make_shared<hlir::framework::Graph>(test_program, target_);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
LOG_IF(WARNING, graph->fusion_groups.size() > 1) << "Test Graph has more than 1 group"; LOG_IF(WARNING, graph->fusion_groups.size() > 1)
auto& dtype_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); << "Test Graph has more than 1 group";
auto& shape_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); auto& dtype_dict =
graph->GetMutableAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
auto& shape_dict = graph->GetMutableAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_); hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
if (apply_manual_schedule) { if (apply_manual_schedule) {
lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front()); lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
} else { } else {
lowered_funcs_ = op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front()); lowered_funcs_ =
op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
} }
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty"; CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
...@@ -76,20 +82,22 @@ std::string TestAutoGenRuleBase::GetIR(const ir::IRSchedule& schedule) { ...@@ -76,20 +82,22 @@ std::string TestAutoGenRuleBase::GetIR(const ir::IRSchedule& schedule) {
const auto& exprs = schedule.GetModule().GetExprs(); const auto& exprs = schedule.GetModule().GetExprs();
std::stringstream module_stream; std::stringstream module_stream;
for (auto i = 0; i < exprs.size(); ++i) { for (auto i = 0; i < exprs.size(); ++i) {
module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr " << i << "\n"; module_stream << "Expr " << i << " {\n"
<< exprs.at(i) << "\n} // end Expr " << i << "\n";
} }
return module_stream.str(); return module_stream.str();
} }
ir::Module TestAutoGenRuleBase::BuildIRModule(const ir::IRSchedule& schedule) { ir::Module TestAutoGenRuleBase::BuildIRModule(const ir::IRSchedule& schedule) {
auto&& updated_bodys = schedule.GetModule().GetExprs(); auto&& updated_bodys = schedule.GetModule().GetExprs();
CHECK_EQ(lowered_funcs_.size(), updated_bodys.size()) << "associated exprs size not equal"; CHECK_EQ(lowered_funcs_.size(), updated_bodys.size())
<< "associated exprs size not equal";
ir::Module::Builder builder("test_bulder", this->target_); ir::Module::Builder builder("test_bulder", this->target_);
for (int i = 0; i < lowered_funcs_.size(); ++i) { for (int i = 0; i < lowered_funcs_.size(); ++i) {
ir::Expr func_body = updated_bodys.at(i); ir::Expr func_body = updated_bodys.at(i);
const ir::LoweredFunc& ori_func = lowered_funcs_.at(i); const ir::LoweredFunc& ori_func = lowered_funcs_.at(i);
auto&& new_func = UpdateFuncWithNewBody(target_, ori_func, func_body); auto&& new_func = UpdateFuncWithNewBody(target_, ori_func, func_body);
builder.AddFunction(new_func); builder.AddFunction(new_func);
} }
...@@ -102,20 +110,24 @@ std::string TestAutoGenRuleBase::GenSourceCode(const ir::Module& ir_module) { ...@@ -102,20 +110,24 @@ std::string TestAutoGenRuleBase::GenSourceCode(const ir::Module& ir_module) {
if (target_ == common::DefaultNVGPUTarget()) { if (target_ == common::DefaultNVGPUTarget()) {
codegen = std::make_unique<backends::CodeGenCUDA_Dev>(this->target_); codegen = std::make_unique<backends::CodeGenCUDA_Dev>(this->target_);
} else { } else {
codegen = std::make_unique<backends::CodeGenCX86>(this->target_, CodeGenCX86::Feature::AVX512); codegen = std::make_unique<backends::CodeGenCX86>(
this->target_, CodeGenCX86::Feature::AVX512);
} }
#else #else
codegen = std::make_unique<backends::CodeGenCX86>(this->target_, CodeGenCX86::Feature::AVX512); codegen = std::make_unique<backends::CodeGenCX86>(
this->target_, CodeGenCX86::Feature::AVX512);
#endif #endif
codegen->SetInlineBuiltinCodes(false); codegen->SetInlineBuiltinCodes(false);
return codegen->Compile(ir_module, CodeGenC::OutputKind::CImpl); return codegen->Compile(ir_module, CodeGenC::OutputKind::CImpl);
} }
raw_func_type TestAutoGenRuleBase::GenExecutableKernel(const ir::Module& ir_module) { raw_func_type TestAutoGenRuleBase::GenExecutableKernel(
const ir::Module& ir_module) {
auto&& func_name = lowered_funcs_.front()->name; auto&& func_name = lowered_funcs_.front()->name;
// Compile to machine code // Compile to machine code
backend_compier_->Build(ir_module); backend_compier_->Build(ir_module);
auto test_func_ptr = reinterpret_cast<void (*)(void**, int32_t)>(backend_compier_->Lookup(func_name)); auto test_func_ptr = reinterpret_cast<void (*)(void**, int32_t)>(
backend_compier_->Lookup(func_name));
return test_func_ptr; return test_func_ptr;
} }
...@@ -138,15 +150,19 @@ void MemoryCopy(const float* src, float* dst, int numel, std::string type) { ...@@ -138,15 +150,19 @@ void MemoryCopy(const float* src, float* dst, int numel, std::string type) {
} }
} }
void AddDataToScope( void AddDataToScope(Scope* scope,
Scope* scope, const common::Target& target, float* data_ptr, std::string name, const std::vector<int>& shape) { const common::Target& target,
auto* var = scope->Var<Tensor>(name); float* data_ptr,
std::string name,
const std::vector<int>& shape) {
auto* var = scope->Var<Tensor>(name);
auto& tensor = absl::get<Tensor>(*var); auto& tensor = absl::get<Tensor>(*var);
CHECK(shape.size()) << "The size of shape can not be 0."; CHECK(shape.size()) << "The size of shape can not be 0.";
Shape cinn_shape(shape); Shape cinn_shape(shape);
tensor->Resize(cinn_shape); tensor->Resize(cinn_shape);
auto* tgt_data_ptr = tensor->mutable_data<float>(target); auto* tgt_data_ptr = tensor->mutable_data<float>(target);
std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; std::string mem_cpy_type =
target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost";
MemoryCopy(data_ptr, tgt_data_ptr, cinn_shape.numel(), mem_cpy_type); MemoryCopy(data_ptr, tgt_data_ptr, cinn_shape.numel(), mem_cpy_type);
} }
...@@ -159,16 +175,20 @@ void CheckResult(raw_func_type test_func, ...@@ -159,16 +175,20 @@ void CheckResult(raw_func_type test_func,
const common::Target& target) { const common::Target& target) {
CHECK(input_names.size()) << "The number of inputs must be greater than 0."; CHECK(input_names.size()) << "The number of inputs must be greater than 0.";
CHECK(output_names.size()) << "The number of outputs must be greater than 0."; CHECK(output_names.size()) << "The number of outputs must be greater than 0.";
CHECK_EQ(input_names.size(), input_shapes.size()) << "The quantity of input_names and input_shapes must be equal."; CHECK_EQ(input_names.size(), input_shapes.size())
<< "The quantity of input_names and input_shapes must be equal.";
CHECK_EQ(output_names.size(), output_shapes.size()) CHECK_EQ(output_names.size(), output_shapes.size())
<< "The quantity of output_names and output_shapes must be equal."; << "The quantity of output_names and output_shapes must be equal.";
// Initialize data // Initialize data
std::vector<float*> input_data_ptrs(input_names.size()); std::vector<float*> input_data_ptrs(input_names.size());
for (int i = 0; i < input_shapes.size(); ++i) { for (int i = 0; i < input_shapes.size(); ++i) {
int input_data_numel = int input_data_numel = std::accumulate(
std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) { return a * b; }); input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) {
input_data_ptrs[i] = reinterpret_cast<float*>(malloc(input_data_numel * sizeof(float))); return a * b;
});
input_data_ptrs[i] =
reinterpret_cast<float*>(malloc(input_data_numel * sizeof(float)));
for (int j = 0; j < input_data_numel; ++j) { for (int j = 0; j < input_data_numel; ++j) {
input_data_ptrs[i][j] = (rand() * 1.f) / RAND_MAX; input_data_ptrs[i][j] = (rand() * 1.f) / RAND_MAX;
} }
...@@ -177,24 +197,35 @@ void CheckResult(raw_func_type test_func, ...@@ -177,24 +197,35 @@ void CheckResult(raw_func_type test_func,
std::vector<float*> expected_output_data_ptrs(output_names.size()); std::vector<float*> expected_output_data_ptrs(output_names.size());
std::vector<int> output_data_numels(output_shapes.size()); std::vector<int> output_data_numels(output_shapes.size());
for (int i = 0; i < output_shapes.size(); ++i) { for (int i = 0; i < output_shapes.size(); ++i) {
output_data_numels[i] = output_data_numels[i] = std::accumulate(
std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) { return a * b; }); output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) {
test_output_data_ptrs[i] = reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float))); return a * b;
});
test_output_data_ptrs[i] =
reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(test_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); memset(test_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
expected_output_data_ptrs[i] = reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float))); expected_output_data_ptrs[i] =
memset(expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(
expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
} }
auto launch_kernel_fn = [&](raw_func_type& raw_func, std::vector<float*>& output_data_ptrs) { auto launch_kernel_fn = [&](raw_func_type& raw_func,
std::vector<float*>& output_data_ptrs) {
// Initialize scope // Initialize scope
Scope scope; Scope scope;
// Initialize input data in scope. // Initialize input data in scope.
for (int i = 0; i < input_names.size(); ++i) { for (int i = 0; i < input_names.size(); ++i) {
AddDataToScope(&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]); AddDataToScope(
&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]);
} }
// Initialize output data in scope. // Initialize output data in scope.
for (int i = 0; i < output_names.size(); ++i) { for (int i = 0; i < output_names.size(); ++i) {
AddDataToScope(&scope, target, output_data_ptrs[i], output_names[i], output_shapes[i]); AddDataToScope(&scope,
target,
output_data_ptrs[i],
output_names[i],
output_shapes[i]);
} }
// Create Instruction and run // Create Instruction and run
...@@ -207,9 +238,12 @@ void CheckResult(raw_func_type test_func, ...@@ -207,9 +238,12 @@ void CheckResult(raw_func_type test_func,
// data // data
for (int i = 0; i < output_names.size(); ++i) { for (int i = 0; i < output_names.size(); ++i) {
const float* result_ptr = scope.GetTensor(output_names[i])->data<float>(); const float* result_ptr = scope.GetTensor(output_names[i])->data<float>();
std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; std::string mem_cpy_type = target == common::DefaultNVGPUTarget()
MemoryCopy(result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type); ? "DeviceToHost"
: "HostToHost";
MemoryCopy(
result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type);
} }
}; };
...@@ -220,7 +254,8 @@ void CheckResult(raw_func_type test_func, ...@@ -220,7 +254,8 @@ void CheckResult(raw_func_type test_func,
// Check result // Check result
for (int i = 0; i < output_shapes.size(); ++i) { for (int i = 0; i < output_shapes.size(); ++i) {
for (int j = 0; j < output_data_numels[i]; ++j) { for (int j = 0; j < output_data_numels[i]; ++j) {
ASSERT_NEAR(test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4); ASSERT_NEAR(
test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4);
} }
} }
......
...@@ -47,15 +47,18 @@ class TestAutoGenRuleBase : public ::testing::Test { ...@@ -47,15 +47,18 @@ class TestAutoGenRuleBase : public ::testing::Test {
// Initialize context for specified target // Initialize context for specified target
void Initialize(const common::Target& target); void Initialize(const common::Target& target);
// construct an ir::IRSchedule by lowering the specified for following AutoGenRule test // construct an ir::IRSchedule by lowering the specified for following
ir::IRSchedule MakeIRSchedule(const frontend::Program& test_program, // AutoGenRule test
utils::LinearRandomEngine::StateType rand_seed = -1, ir::IRSchedule MakeIRSchedule(
bool apply_manual_schedule = false); const frontend::Program& test_program,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool apply_manual_schedule = false);
// Get the IR of bodies in IRSchedule // Get the IR of bodies in IRSchedule
std::string GetIR(const ir::IRSchedule& schedule); std::string GetIR(const ir::IRSchedule& schedule);
// build ir::Module from the original lowered funcs with their bodies updated by the schedule // build ir::Module from the original lowered funcs with their bodies updated
// by the schedule
ir::Module BuildIRModule(const ir::IRSchedule& schedule); ir::Module BuildIRModule(const ir::IRSchedule& schedule);
// generate source code with the built ir module // generate source code with the built ir module
...@@ -75,9 +78,12 @@ class TestAutoGenRuleBase : public ::testing::Test { ...@@ -75,9 +78,12 @@ class TestAutoGenRuleBase : public ::testing::Test {
* @params-2: Expected function pointer for comparison. * @params-2: Expected function pointer for comparison.
* @params-3: Names of input data. * @params-3: Names of input data.
* @params-4: Names of output data. * @params-4: Names of output data.
* @params-5: Shapes of the input data, each input corresponds to a std::vector<int>. * @params-5: Shapes of the input data, each input corresponds to a
* @params-6: Shapes of the output data, each output corresponds to a std::vector<int>. * std::vector<int>.
* @params-7: The Target expressing computing platform and architecture of the function to be tested. * @params-6: Shapes of the output data, each output corresponds to a
* std::vector<int>.
* @params-7: The Target expressing computing platform and architecture of the
* function to be tested.
* @return: void * @return: void
*/ */
void CheckResult(raw_func_type test_func, void CheckResult(raw_func_type test_func,
......
...@@ -26,24 +26,30 @@ namespace auto_schedule { ...@@ -26,24 +26,30 @@ namespace auto_schedule {
class SearchState; class SearchState;
// Select the next block to be operated for SearchState during the search process // Select the next block to be operated for SearchState during the search
// process
class BlockSampler { class BlockSampler {
public: public:
/** /**
* @brief Create a BlockSampler with the specific strategy name and necessary construct parameters. * @brief Create a BlockSampler with the specific strategy name and necessary
* construct parameters.
* @param all_blocks All possible blocks to be sampled. * @param all_blocks All possible blocks to be sampled.
* @param default_remove_policy The default option to determine whether to delete the next block after selecting it. * @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The block sampling strategy. * @param strategy The block sampling strategy.
* Currently, the available strategies are "traversal" and "probabilistic", * Currently, the available strategies are "traversal" and
* where "traversal" means to select blocks one by one until all blocks are traversed, * "probabilistic", where "traversal" means to select blocks one by one until
* and "probabilistic" means randomly picking blocks according to the given distribution. * all blocks are traversed, and "probabilistic" means randomly picking blocks
* @param weights Used for the probabilistic policy, giving each candidate a weight. * according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/ */
static std::unique_ptr<BlockSampler> Make(const std::vector<ir::Expr>& all_blocks, static std::unique_ptr<BlockSampler> Make(
bool default_remove_policy = true, const std::vector<ir::Expr>& all_blocks,
const std::string& strategy = "traversal", bool default_remove_policy = true,
utils::LinearRandomEngine::StateType rand_seed = 0, const std::string& strategy = "traversal",
const std::vector<int>& weights = {}); utils::LinearRandomEngine::StateType rand_seed = 0,
const std::vector<int>& weights = {});
// Return the name of sample strategy // Return the name of sample strategy
virtual const char* Name() const = 0; virtual const char* Name() const = 0;
...@@ -56,18 +62,22 @@ class BlockSampler { ...@@ -56,18 +62,22 @@ class BlockSampler {
protected: protected:
// A BlockSampler object should be created with the static function Make() // A BlockSampler object should be created with the static function Make()
BlockSampler(const std::vector<ir::Expr>& all_blocks, bool default_remove_policy); BlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy);
// Select a block to apply rule // Select a block to apply rule
// The param remove is used to determine whether to delete the next block after selecting it, // The param remove is used to determine whether to delete the next block
// If remove == true, it will not be sampled in the future. // after selecting it, If remove == true, it will not be sampled in the
// future.
virtual std::string NextBlock(bool remove) = 0; virtual std::string NextBlock(bool remove) = 0;
// The names of all blocks // The names of all blocks
// Because the Block Expr will be changed in the search process, the name is saved for indexing // Because the Block Expr will be changed in the search process, the name is
// saved for indexing
std::vector<std::string> all_blocks_; std::vector<std::string> all_blocks_;
// The default policy to determine whether to delete the next block after selecting it. // The default policy to determine whether to delete the next block after
// selecting it.
bool default_remove_policy_; bool default_remove_policy_;
}; };
...@@ -75,7 +85,8 @@ class BlockSampler { ...@@ -75,7 +85,8 @@ class BlockSampler {
// witch means to select blocks one by one until all blocks are traversed. // witch means to select blocks one by one until all blocks are traversed.
class TraversalBlockSampler : public BlockSampler { class TraversalBlockSampler : public BlockSampler {
public: public:
TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks, bool default_remove_policy) TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy)
: BlockSampler(all_blocks, default_remove_policy), cur_idx_(0) {} : BlockSampler(all_blocks, default_remove_policy), cur_idx_(0) {}
const char* Name() const override { return "traversal"; } const char* Name() const override { return "traversal"; }
...@@ -96,7 +107,7 @@ class ProbabilisticBlockSampler : public BlockSampler { ...@@ -96,7 +107,7 @@ class ProbabilisticBlockSampler : public BlockSampler {
ProbabilisticBlockSampler(const std::vector<ir::Expr>& all_blocks, ProbabilisticBlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy, bool default_remove_policy,
utils::LinearRandomEngine::StateType rand_seed = 0, utils::LinearRandomEngine::StateType rand_seed = 0,
const std::vector<int>& weights = {}); const std::vector<int>& weights = {});
const char* Name() const override { return "probabilistic"; } const char* Name() const override { return "probabilistic"; }
......
...@@ -24,7 +24,8 @@ namespace auto_schedule { ...@@ -24,7 +24,8 @@ namespace auto_schedule {
std::vector<ir::Expr> CreateTestBlocks() { std::vector<ir::Expr> CreateTestBlocks() {
std::vector<ir::Expr> blocks; std::vector<ir::Expr> blocks;
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
ir::Expr block = ir::ScheduleBlock::Make({}, {}, {}, "block_" + std::to_string(i), ir::Expr()); ir::Expr block = ir::ScheduleBlock::Make(
{}, {}, {}, "block_" + std::to_string(i), ir::Expr());
blocks.push_back(ir::ScheduleBlockRealize::Make({}, block)); blocks.push_back(ir::ScheduleBlockRealize::Make({}, block));
} }
return blocks; return blocks;
...@@ -32,9 +33,11 @@ std::vector<ir::Expr> CreateTestBlocks() { ...@@ -32,9 +33,11 @@ std::vector<ir::Expr> CreateTestBlocks() {
TEST(BlockSampler, Make) { TEST(BlockSampler, Make) {
std::vector<ir::Expr> mock_blocks = CreateTestBlocks(); std::vector<ir::Expr> mock_blocks = CreateTestBlocks();
auto traversal_block_sampler = BlockSampler::Make(mock_blocks, true, "traversal"); auto traversal_block_sampler =
BlockSampler::Make(mock_blocks, true, "traversal");
ASSERT_STREQ(traversal_block_sampler->Name(), "traversal"); ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
auto probabilistic_block_sampler = BlockSampler::Make(mock_blocks, true, "probabilistic"); auto probabilistic_block_sampler =
BlockSampler::Make(mock_blocks, true, "probabilistic");
ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic"); ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
} }
...@@ -54,15 +57,17 @@ TEST(TraversalBlockSampler, NextBlock) { ...@@ -54,15 +57,17 @@ TEST(TraversalBlockSampler, NextBlock) {
} }
TEST(ProbabilisticBlockSampler, NextBlock) { TEST(ProbabilisticBlockSampler, NextBlock) {
std::vector<ir::Expr> blocks = CreateTestBlocks(); std::vector<ir::Expr> blocks = CreateTestBlocks();
auto probabilistic_block_sampler = BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1}); auto probabilistic_block_sampler =
BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1});
std::string block_name; std::string block_name;
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
block_name = probabilistic_block_sampler->NextBlock(); block_name = probabilistic_block_sampler->NextBlock();
VLOG(6) << "next block name: " << block_name; VLOG(6) << "next block name: " << block_name;
} }
probabilistic_block_sampler = BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1}); probabilistic_block_sampler =
BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1});
probabilistic_block_sampler->NextBlock(); probabilistic_block_sampler->NextBlock();
probabilistic_block_sampler->NextBlock(); probabilistic_block_sampler->NextBlock();
probabilistic_block_sampler->NextBlock(); probabilistic_block_sampler->NextBlock();
......
...@@ -30,20 +30,25 @@ class SearchState; ...@@ -30,20 +30,25 @@ class SearchState;
class RuleSampler { class RuleSampler {
public: public:
/** /**
* @brief Create a RuleSampler with the specific strategy name and necessary construct parameters. * @brief Create a RuleSampler with the specific strategy name and necessary
* construct parameters.
* @param potential_rules All possible rules to be sampled. * @param potential_rules All possible rules to be sampled.
* @param default_remove_policy The default option to determine whether to delete the next block after selecting it. * @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The rule sampling strategy. * @param strategy The rule sampling strategy.
* Currently, the available strategies are "traversal" and "probabilistic", * Currently, the available strategies are "traversal" and
* where "traversal" means to select rules one by one until all rules are traversed, * "probabilistic", where "traversal" means to select rules one by one until
* and "probabilistic" means randomly picking rules according to the given distribution. * all rules are traversed, and "probabilistic" means randomly picking rules
* @param weights Used for the probabilistic policy, giving each candidate a weight. * according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/ */
static std::unique_ptr<RuleSampler> Make(const std::vector<AutoGenRule*>& potential_rules, static std::unique_ptr<RuleSampler> Make(
bool default_remove_policy = true, const std::vector<AutoGenRule*>& potential_rules,
const std::string& strategy = "traversal", bool default_remove_policy = true,
utils::LinearRandomEngine::StateType rand_seed = 0, const std::string& strategy = "traversal",
const std::vector<int>& weights = {}); utils::LinearRandomEngine::StateType rand_seed = 0,
const std::vector<int>& weights = {});
// Return the name of sample strategy // Return the name of sample strategy
virtual const char* Name() const = 0; virtual const char* Name() const = 0;
...@@ -55,18 +60,21 @@ class RuleSampler { ...@@ -55,18 +60,21 @@ class RuleSampler {
protected: protected:
// A RuleSampler object should be created with the static function Make() // A RuleSampler object should be created with the static function Make()
RuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy) RuleSampler(const std::vector<AutoGenRule*>& potential_rules,
: potential_rules_(&potential_rules), default_remove_policy_(default_remove_policy) {} bool default_remove_policy)
: potential_rules_(&potential_rules),
default_remove_policy_(default_remove_policy) {}
// Select a rule to apply. // Select a rule to apply.
// The param remove is used to determine whether to delete the next rule after selecting it, // The param remove is used to determine whether to delete the next rule after
// If remove == true, it will not be sampled in the future. // selecting it, If remove == true, it will not be sampled in the future.
virtual AutoGenRule* NextRule(bool remove) = 0; virtual AutoGenRule* NextRule(bool remove) = 0;
// The pointer refers to all potential rules // The pointer refers to all potential rules
const std::vector<AutoGenRule*>* potential_rules_; const std::vector<AutoGenRule*>* potential_rules_;
// The default policy to determine whether to delete the next rule after selecting it. // The default policy to determine whether to delete the next rule after
// selecting it.
bool default_remove_policy_; bool default_remove_policy_;
}; };
...@@ -74,7 +82,8 @@ class RuleSampler { ...@@ -74,7 +82,8 @@ class RuleSampler {
// witch means to select rules one by one until all rules are traversed. // witch means to select rules one by one until all rules are traversed.
class TraversalRuleSampler : public RuleSampler { class TraversalRuleSampler : public RuleSampler {
public: public:
TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy) TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy)
: RuleSampler(potential_rules, default_remove_policy), cur_idx_(0) {} : RuleSampler(potential_rules, default_remove_policy), cur_idx_(0) {}
const char* Name() const override { return "traversal"; } const char* Name() const override { return "traversal"; }
...@@ -95,7 +104,7 @@ class ProbabilisticRuleSampler : public RuleSampler { ...@@ -95,7 +104,7 @@ class ProbabilisticRuleSampler : public RuleSampler {
ProbabilisticRuleSampler(const std::vector<AutoGenRule*>& potential_rules, ProbabilisticRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy, bool default_remove_policy,
utils::LinearRandomEngine::StateType rand_seed = 0, utils::LinearRandomEngine::StateType rand_seed = 0,
const std::vector<int>& weights = {}); const std::vector<int>& weights = {});
const char* Name() const override { return "probabilistic"; } const char* Name() const override { return "probabilistic"; }
......
...@@ -28,20 +28,23 @@ Target target = common::DefaultNVGPUTarget(); ...@@ -28,20 +28,23 @@ Target target = common::DefaultNVGPUTarget();
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
std::vector<AutoGenRule*> GenerateTestRules() { return {new AutoUnroll(target), new SkipRule(target)}; } std::vector<AutoGenRule*> GenerateTestRules() {
return {new AutoUnroll(target), new SkipRule(target)};
}
TEST(RuleSampler, Make) { TEST(RuleSampler, Make) {
std::vector<AutoGenRule*> rules = GenerateTestRules(); std::vector<AutoGenRule*> rules = GenerateTestRules();
auto traversal_block_sampler = RuleSampler::Make(rules, true, "traversal"); auto traversal_block_sampler = RuleSampler::Make(rules, true, "traversal");
ASSERT_STREQ(traversal_block_sampler->Name(), "traversal"); ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
auto probabilistic_block_sampler = RuleSampler::Make(rules, true, "probabilistic"); auto probabilistic_block_sampler =
RuleSampler::Make(rules, true, "probabilistic");
ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic"); ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
} }
TEST(TraversalRuleSampler, NextRule) { TEST(TraversalRuleSampler, NextRule) {
std::vector<AutoGenRule*> rules = GenerateTestRules(); std::vector<AutoGenRule*> rules = GenerateTestRules();
auto traversal_rule_sampler = RuleSampler::Make(rules, true, "traversal"); auto traversal_rule_sampler = RuleSampler::Make(rules, true, "traversal");
AutoGenRule* rule = traversal_rule_sampler->NextRule(); AutoGenRule* rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("AutoUnroll", rule->GetRuleName()); ASSERT_EQ("AutoUnroll", rule->GetRuleName());
rule = traversal_rule_sampler->NextRule(); rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("SkipRule", rule->GetRuleName()); ASSERT_EQ("SkipRule", rule->GetRuleName());
...@@ -50,7 +53,7 @@ TEST(TraversalRuleSampler, NextRule) { ...@@ -50,7 +53,7 @@ TEST(TraversalRuleSampler, NextRule) {
ASSERT_EQ("AutoUnroll", rule->GetRuleName()); ASSERT_EQ("AutoUnroll", rule->GetRuleName());
traversal_rule_sampler = RuleSampler::Make(rules, false, "traversal"); traversal_rule_sampler = RuleSampler::Make(rules, false, "traversal");
rule = traversal_rule_sampler->NextRule(); rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("AutoUnroll", rule->GetRuleName()); ASSERT_EQ("AutoUnroll", rule->GetRuleName());
rule = traversal_rule_sampler->NextRule(); rule = traversal_rule_sampler->NextRule();
ASSERT_EQ("AutoUnroll", rule->GetRuleName()); ASSERT_EQ("AutoUnroll", rule->GetRuleName());
...@@ -58,14 +61,16 @@ TEST(TraversalRuleSampler, NextRule) { ...@@ -58,14 +61,16 @@ TEST(TraversalRuleSampler, NextRule) {
TEST(ProbabilisticRuleSampler, NextRule) { TEST(ProbabilisticRuleSampler, NextRule) {
std::vector<AutoGenRule*> rules = GenerateTestRules(); std::vector<AutoGenRule*> rules = GenerateTestRules();
auto probabilistic_rule_sampler = RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1}); auto probabilistic_rule_sampler =
RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
AutoGenRule* rule; AutoGenRule* rule;
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
rule = probabilistic_rule_sampler->NextRule(); rule = probabilistic_rule_sampler->NextRule();
VLOG(6) << "next rule name: " << rule->GetRuleName(); VLOG(6) << "next rule name: " << rule->GetRuleName();
} }
probabilistic_rule_sampler = RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1}); probabilistic_rule_sampler =
RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
probabilistic_rule_sampler->NextRule(); probabilistic_rule_sampler->NextRule();
probabilistic_rule_sampler->NextRule(); probabilistic_rule_sampler->NextRule();
ASSERT_EQ(nullptr, probabilistic_rule_sampler->NextRule()); ASSERT_EQ(nullptr, probabilistic_rule_sampler->NextRule());
......
...@@ -39,18 +39,23 @@ DECLARE_bool(auto_schedule_use_cost_model); ...@@ -39,18 +39,23 @@ DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
SearchSpace::SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed) SearchSpace::SearchSpace(const TuneTask& tune_task,
: tune_task_(tune_task), rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) { utils::LinearRandomEngine::StateType rand_seed)
: tune_task_(tune_task),
rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
const auto& target = tune_task_.target; const auto& target = tune_task_.target;
// initialize a set of rules and they are commonly used by all states // initialize a set of rules and they are commonly used by all states
// TODO(zhhsplendid): pass correct output names to AutoInline // TODO(zhhsplendid): pass correct output names to AutoInline
// sketch_rules_.emplace_back(new AutoInline(target, tune_task_.output_names)); // sketch_rules_.emplace_back(new AutoInline(target,
sketch_rules_.emplace_back(new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch))); // tune_task_.output_names));
sketch_rules_.emplace_back(
new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch)));
sketch_rules_.emplace_back(new AutoUnroll(target)); sketch_rules_.emplace_back(new AutoUnroll(target));
sketch_rules_.emplace_back(new SkipRule(target)); sketch_rules_.emplace_back(new SkipRule(target));
} }
SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model) { SearchState SearchSpace::GetScheduleMutate(const SearchState& state,
const ExprCostModel& cost_model) {
bool has_manual_schedule = false; bool has_manual_schedule = false;
if (has_manual_schedule) { if (has_manual_schedule) {
SearchState ret = ManualScheduleMutate(state); SearchState ret = ManualScheduleMutate(state);
...@@ -58,9 +63,11 @@ SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprC ...@@ -58,9 +63,11 @@ SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprC
} }
SearchState ret = RandomScheduleMutate(state); SearchState ret = RandomScheduleMutate(state);
if (FLAGS_auto_schedule_use_cost_model) { if (FLAGS_auto_schedule_use_cost_model) {
ret->predicted_cost = cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target); ret->predicted_cost =
cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target);
} }
VLOG(4) << JoinStatesDebugString("SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5)); VLOG(4) << JoinStatesDebugString(
"SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5));
return ret; return ret;
} }
...@@ -77,9 +84,10 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { ...@@ -77,9 +84,10 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
SearchState ret(state); SearchState ret(state);
std::vector<RuleApplyType> apply_types(ret->applicable_rules.size()); std::vector<RuleApplyType> apply_types(ret->applicable_rules.size());
for (int idx = 0; idx != ret->applicable_rules.size(); ++idx) { for (int idx = 0; idx != ret->applicable_rules.size(); ++idx) {
AutoGenRule* rule = ret->applicable_rules.at(idx); AutoGenRule* rule = ret->applicable_rules.at(idx);
RuleApplyType apply_type = rule->Init(&ret->ir_schedule); RuleApplyType apply_type = rule->Init(&ret->ir_schedule);
VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "=" << static_cast<int>(apply_type); VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "="
<< static_cast<int>(apply_type);
apply_types[idx] = apply_type; apply_types[idx] = apply_type;
if (apply_type != RuleApplyType::kCannotApply) { if (apply_type != RuleApplyType::kCannotApply) {
weight_to_rule_index[cur_weight] = idx; weight_to_rule_index[cur_weight] = idx;
...@@ -94,7 +102,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { ...@@ -94,7 +102,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
} }
// 3. Sample a schedule on the distribution // 3. Sample a schedule on the distribution
int sample_weighted_index = utils::SampleUniformInt(0, cur_weight, &rand_seed_); int sample_weighted_index =
utils::SampleUniformInt(0, cur_weight, &rand_seed_);
auto iter = weight_to_rule_index.upper_bound(sample_weighted_index); auto iter = weight_to_rule_index.upper_bound(sample_weighted_index);
--iter; --iter;
...@@ -102,13 +111,15 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { ...@@ -102,13 +111,15 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
int sample_rule_index = iter->second; int sample_rule_index = iter->second;
CHECK_LT(sample_rule_index, ret->applicable_rules.size()); CHECK_LT(sample_rule_index, ret->applicable_rules.size());
AutoGenRule* sample_rule = ret->applicable_rules.at(sample_rule_index); AutoGenRule* sample_rule = ret->applicable_rules.at(sample_rule_index);
VLOG(7) << "Apply rule: " << sample_rule->GetRuleName() << " with index=" << sample_weighted_index - iter->first; VLOG(7) << "Apply rule: " << sample_rule->GetRuleName()
<< " with index=" << sample_weighted_index - iter->first;
// 4. Apply the schedule change // 4. Apply the schedule change
sample_rule->Apply(sample_weighted_index - iter->first); sample_rule->Apply(sample_weighted_index - iter->first);
// 5. Remove the rule after applying it // 5. Remove the rule after applying it
if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) { if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) {
ret->applicable_rules.erase(ret->applicable_rules.begin() + sample_rule_index); ret->applicable_rules.erase(ret->applicable_rules.begin() +
sample_rule_index);
} }
return ret; return ret;
...@@ -116,17 +127,20 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { ...@@ -116,17 +127,20 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) { std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
VLOG(5) << "SearchSpace::GetRandomInitialSketch with num=" << num; VLOG(5) << "SearchSpace::GetRandomInitialSketch with num=" << num;
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), ir::IRSchedule init_schedule(
utils::ForkRandomState(&rand_seed_)); ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_));
std::vector<AutoGenRule*> init_rules; std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end(), std::back_inserter(init_rules), [](const auto& rule) { std::transform(sketch_rules_.begin(),
return rule.get(); sketch_rules_.end(),
}); std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
std::vector<SearchState> result; std::vector<SearchState> result;
while (result.size() < num) { while (result.size() < num) {
SearchState state(init_schedule, SearchState::NOT_INIT_COST, init_rules); SearchState state(init_schedule, SearchState::NOT_INIT_COST, init_rules);
for (int i = 0; i < init_sketch_random_depth_; ++i) { for (int i = 0; i < init_sketch_random_depth_; ++i) {
VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: " << i; VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: "
<< i;
state = RandomScheduleMutate(state); state = RandomScheduleMutate(state);
if (state->applicable_rules.empty()) { if (state->applicable_rules.empty()) {
break; break;
...@@ -134,7 +148,9 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) { ...@@ -134,7 +148,9 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
} }
VLOG(5) << JoinStatesDebugString( VLOG(5) << JoinStatesDebugString(
"SearchSpace::GetRandomInitialSketch-New_Sketch", {state}, /*verbose=*/VLOG_IS_ON(6)); "SearchSpace::GetRandomInitialSketch-New_Sketch",
{state},
/*verbose=*/VLOG_IS_ON(6));
result.emplace_back(std::move(state)); result.emplace_back(std::move(state));
} }
return result; return result;
...@@ -142,24 +158,28 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) { ...@@ -142,24 +158,28 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() { std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
VLOG(5) << "SearchSpace::InitSketchWithRandomPrunedStrategy"; VLOG(5) << "SearchSpace::InitSketchWithRandomPrunedStrategy";
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), ir::IRSchedule init_schedule(
utils::ForkRandomState(&rand_seed_)); ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
auto all_blocks = init_schedule.GetAllBlocks(); utils::ForkRandomState(&rand_seed_));
auto block_sampler = BlockSampler::Make(all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_)); auto all_blocks = init_schedule.GetAllBlocks();
auto block_sampler = BlockSampler::Make(
all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_));
std::vector<AutoGenRule*> init_rules; std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) { std::transform(sketch_rules_.begin(),
return rule.get(); sketch_rules_.end() - 1,
}); std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
CHECK(init_rules.size() > 0) << "number of init rules cannot be 0"; CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {}); SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
std::vector<SearchState> states_buf1{init_state}, states_buf2; std::vector<SearchState> states_buf1{init_state}, states_buf2;
std::vector<SearchState>* p_states_cur = &states_buf1; std::vector<SearchState>* p_states_cur = &states_buf1;
std::vector<SearchState>* p_states_next = &states_buf2; std::vector<SearchState>* p_states_next = &states_buf2;
int total_steps = 0, steps; int total_steps = 0, steps;
std::string block_name; std::string block_name;
while ("" != (block_name = block_sampler->NextBlock()) && total_steps < init_sketch_random_depth_) { while ("" != (block_name = block_sampler->NextBlock()) &&
total_steps < init_sketch_random_depth_) {
steps = utils::SampleUniformInt(1, init_rules.size() + 1, &rand_seed_); steps = utils::SampleUniformInt(1, init_rules.size() + 1, &rand_seed_);
if (total_steps + steps > init_sketch_random_depth_) { if (total_steps + steps > init_sketch_random_depth_) {
steps = init_sketch_random_depth_ - total_steps; steps = init_sketch_random_depth_ - total_steps;
...@@ -167,51 +187,66 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() { ...@@ -167,51 +187,66 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
total_steps += steps; total_steps += steps;
p_states_next->clear(); p_states_next->clear();
for (const auto& state : *p_states_cur) { for (const auto& state : *p_states_cur) {
auto rule_sampler = RuleSampler::Make(init_rules, true, "probabilistic", utils::ForkRandomState(&rand_seed_)); auto rule_sampler =
auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), steps, false, 1); RuleSampler::Make(init_rules,
p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end()); true,
"probabilistic",
utils::ForkRandomState(&rand_seed_));
auto new_states = ApplySketchRule(
state, block_name, rule_sampler.get(), steps, false, 1);
p_states_next->insert(
p_states_next->end(), new_states.begin(), new_states.end());
} }
std::swap(p_states_cur, p_states_next); std::swap(p_states_cur, p_states_next);
} }
VLOG(5) << JoinStatesDebugString( VLOG(5) << JoinStatesDebugString(
"SearchSpace::InitSketchWithRandomPrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6)); "SearchSpace::InitSketchWithRandomPrunedStrategy",
*p_states_cur,
/*verbose=*/VLOG_IS_ON(6));
return *p_states_cur; return *p_states_cur;
} }
std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() { std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() {
VLOG(5) << "SearchSpace::InitSketchWithRulePrunedStrategy"; VLOG(5) << "SearchSpace::InitSketchWithRulePrunedStrategy";
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), ir::IRSchedule init_schedule(
utils::ForkRandomState(&rand_seed_)); ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_));
auto all_blocks = init_schedule.GetAllBlocks(); auto all_blocks = init_schedule.GetAllBlocks();
std::reverse(all_blocks.begin(), all_blocks.end()); std::reverse(all_blocks.begin(), all_blocks.end());
auto block_sampler = BlockSampler::Make(all_blocks, true, "traversal"); auto block_sampler = BlockSampler::Make(all_blocks, true, "traversal");
std::vector<AutoGenRule*> init_rules; std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) { std::transform(sketch_rules_.begin(),
return rule.get(); sketch_rules_.end() - 1,
}); std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
CHECK(init_rules.size() > 0) << "number of init rules cannot be 0"; CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {}); SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
std::vector<SearchState> states_buf1{init_state}, states_buf2; std::vector<SearchState> states_buf1{init_state}, states_buf2;
std::vector<SearchState>* p_states_cur = &states_buf1; std::vector<SearchState>* p_states_cur = &states_buf1;
std::vector<SearchState>* p_states_next = &states_buf2; std::vector<SearchState>* p_states_next = &states_buf2;
std::string block_name; std::string block_name;
while ("" != (block_name = block_sampler->NextBlock())) { while ("" != (block_name = block_sampler->NextBlock())) {
p_states_next->clear(); p_states_next->clear();
for (const auto& state : *p_states_cur) { for (const auto& state : *p_states_cur) {
auto rule_sampler = RuleSampler::Make(init_rules, true, "traversal"); auto rule_sampler = RuleSampler::Make(init_rules, true, "traversal");
auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), 0, true); auto new_states =
p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end()); ApplySketchRule(state, block_name, rule_sampler.get(), 0, true);
p_states_next->insert(
p_states_next->end(), new_states.begin(), new_states.end());
} }
std::swap(p_states_cur, p_states_next); std::swap(p_states_cur, p_states_next);
} }
VLOG(5) << JoinStatesDebugString( VLOG(5) << JoinStatesDebugString(
"SearchSpace::InitSketchWithRulePrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6)); "SearchSpace::InitSketchWithRulePrunedStrategy",
*p_states_cur,
/*verbose=*/VLOG_IS_ON(6));
return *p_states_cur; return *p_states_cur;
} }
std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::string& strategy) { std::vector<SearchState> SearchSpace::GenerateSketches(
int num, const std::string& strategy) {
VLOG(4) << "SearchSpace::GenerateSketches with num = " << num; VLOG(4) << "SearchSpace::GenerateSketches with num = " << num;
if (strategy == "random") { if (strategy == "random") {
...@@ -239,28 +274,33 @@ std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::strin ...@@ -239,28 +274,33 @@ std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::strin
} }
} }
} }
VLOG(4) << JoinStatesDebugString("SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5)); VLOG(4) << JoinStatesDebugString(
"SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5));
return result; return result;
} }
std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state, std::vector<SearchState> SearchSpace::ApplySketchRule(
const std::string& block_name, const SearchState& state,
RuleSampler* rule_sampler, const std::string& block_name,
int steps, RuleSampler* rule_sampler,
bool prune_by_rule, int steps,
double prune_probability) { bool prune_by_rule,
double prune_probability) {
std::list<SearchState> layer{state}; std::list<SearchState> layer{state};
int step = 0; int step = 0;
AutoGenRule* rule; AutoGenRule* rule;
// After determining a SearchState and a block, each rule has two possibilities: apply and not apply. // After determining a SearchState and a block, each rule has two
// In all transfer spaces, select a rule at each step, and collect all possible new states arrived by apply and not // possibilities: apply and not apply. In all transfer spaces, select a rule
// apply. This forms a tree, and we can use rule pruning or random pruning to reduce the number of sketches. // at each step, and collect all possible new states arrived by apply and not
// apply. This forms a tree, and we can use rule pruning or random pruning to
// reduce the number of sketches.
VLOG(6) << "Collect the states of all transfers within steps: " << steps; VLOG(6) << "Collect the states of all transfers within steps: " << steps;
while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) { while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) {
VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName(); VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName();
std::list<SearchState> new_states; std::list<SearchState> new_states;
int id = 0; int id = 0;
for (std::list<SearchState>::iterator iter = layer.begin(); iter != layer.end();) { for (std::list<SearchState>::iterator iter = layer.begin();
iter != layer.end();) {
// Some rules will reduce the number of blocks, such as AutoInline, // Some rules will reduce the number of blocks, such as AutoInline,
// so we need to check whether the SearchState still has the block. // so we need to check whether the SearchState still has the block.
if (!(*iter)->ir_schedule.HasBlock(block_name)) { if (!(*iter)->ir_schedule.HasBlock(block_name)) {
...@@ -268,21 +308,26 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state, ...@@ -268,21 +308,26 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
continue; continue;
} }
auto type = rule->AnalyseApplyType(*iter, block_name); auto type = rule->AnalyseApplyType(*iter, block_name);
VLOG(7) << "At SearchState " << ++id VLOG(7)
<< ", apply type = " << static_cast<typename std::underlying_type<RuleApplyType>::type>(type); << "At SearchState " << ++id << ", apply type = "
<< static_cast<typename std::underlying_type<RuleApplyType>::type>(
type);
// if cannot apply the rule, skip it // if cannot apply the rule, skip it
if (type == RuleApplyType::kCannotApply) { if (type == RuleApplyType::kCannotApply) {
++iter; ++iter;
continue; continue;
} }
// if can apply the rule, apply it and determine whether to prune the branch that do not apply // if can apply the rule, apply it and determine whether to prune the
std::vector<SearchState> tmp_states = rule->ApplyOnBlock(*iter, block_name); // branch that do not apply
std::vector<SearchState> tmp_states =
rule->ApplyOnBlock(*iter, block_name);
new_states.insert(new_states.end(), tmp_states.begin(), tmp_states.end()); new_states.insert(new_states.end(), tmp_states.begin(), tmp_states.end());
bool need_prune = false; bool need_prune = false;
if (prune_by_rule) { if (prune_by_rule) {
need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules); need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules);
} else { } else {
need_prune = (utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability); need_prune =
(utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability);
} }
if (need_prune) { if (need_prune) {
iter = layer.erase(iter); iter = layer.erase(iter);
...@@ -290,10 +335,12 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state, ...@@ -290,10 +335,12 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
++iter; ++iter;
} }
} }
VLOG(7) << "apply on block: " << block_name << ", generate " << new_states.size() << " new states at step " << step; VLOG(7) << "apply on block: " << block_name << ", generate "
<< new_states.size() << " new states at step " << step;
layer.splice(layer.end(), std::move(new_states)); layer.splice(layer.end(), std::move(new_states));
} }
VLOG(6) << "apply on block: " << block_name << ", generate " << layer.size() - 1 << " more states at all"; VLOG(6) << "apply on block: " << block_name << ", generate "
<< layer.size() - 1 << " more states at all";
return std::vector<SearchState>(layer.begin(), layer.end()); return std::vector<SearchState>(layer.begin(), layer.end());
} }
......
...@@ -40,24 +40,31 @@ namespace auto_schedule { ...@@ -40,24 +40,31 @@ namespace auto_schedule {
*/ */
class SearchSpace { class SearchSpace {
public: public:
SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed = -1); SearchSpace(const TuneTask& tune_task,
utils::LinearRandomEngine::StateType rand_seed = -1);
// Sketch mutate, returns the mutated ModuleExpr and estimited cost // Sketch mutate, returns the mutated ModuleExpr and estimited cost
virtual SearchState GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model); virtual SearchState GetScheduleMutate(const SearchState& state,
const ExprCostModel& cost_model);
/** /**
* \brief Generate sketch as initial population of evolutionary search. * \brief Generate sketch as initial population of evolutionary search.
* @param num The number of sketches to generate. * @param num The number of sketches to generate.
* @param strategy The strategy to generate sketchs, * @param strategy The strategy to generate sketchs,
* Current optional strategies are "rule_prune" or "random_prune" or "random". * Current optional strategies are "rule_prune" or "random_prune" or
* - "rule_prune": will use rules to prune and generate sketches as efficiently as possible. * "random".
* - "random_prune": will use the new interface ApplySketchRules() to simulate the random generation of sketches, * - "rule_prune": will use rules to prune and generate sketches as
* and supports the function of a rule returning multiple SearchStates and random pruning by probability. * efficiently as possible.
* - "random": will randomly select a block and a rule to apply and repeat this step several times, * - "random_prune": will use the new interface ApplySketchRules() to simulate
* however, each rule can only be used on one SearchState at most once. * the random generation of sketches, and supports the function of a rule
* returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat
* this step several times, however, each rule can only be used on one
* SearchState at most once.
* @return Generated sketchs. * @return Generated sketchs.
*/ */
virtual std::vector<SearchState> GenerateSketches(int num, const std::string& strategy); virtual std::vector<SearchState> GenerateSketches(
int num, const std::string& strategy);
private: private:
// TODO(zhhsplendid): mutate by manual schedule. // TODO(zhhsplendid): mutate by manual schedule.
...@@ -69,20 +76,24 @@ class SearchSpace { ...@@ -69,20 +76,24 @@ class SearchSpace {
// Generate num sketchs, each with several rounds of SketchMutate // Generate num sketchs, each with several rounds of SketchMutate
std::vector<SearchState> InitSketchWithRandomStrategy(int num); std::vector<SearchState> InitSketchWithRandomStrategy(int num);
// Generate sketch pruned randomly as initial population of evolutionary search // Generate sketch pruned randomly as initial population of evolutionary
// search
std::vector<SearchState> InitSketchWithRandomPrunedStrategy(); std::vector<SearchState> InitSketchWithRandomPrunedStrategy();
// Generate sketch pruned by rules as initial population of evolutionary search // Generate sketch pruned by rules as initial population of evolutionary
// search
std::vector<SearchState> InitSketchWithRulePrunedStrategy(); std::vector<SearchState> InitSketchWithRulePrunedStrategy();
/** /**
* @brief Collect the new states that may be transferred to after applying several rules on a block from a certain * @brief Collect the new states that may be transferred to after applying
* state. * several rules on a block from a certain state.
* @param state Starting point of state transition. * @param state Starting point of state transition.
* @param block_name Name of the block to apply the rules to. * @param block_name Name of the block to apply the rules to.
* @param rule_sampler Sampler that samples the new rule to apply on the block. * @param rule_sampler Sampler that samples the new rule to apply on the
* block.
* @param steps Number of steps to apply the rule. * @param steps Number of steps to apply the rule.
* @param prune_by_rule If true, prune the state transition tree by rule, otherwise prune randomly. * @param prune_by_rule If true, prune the state transition tree by rule,
* otherwise prune randomly.
* @param prune_probability Pruning probability of random pruning. * @param prune_probability Pruning probability of random pruning.
*/ */
std::vector<SearchState> ApplySketchRule(const SearchState& state, std::vector<SearchState> ApplySketchRule(const SearchState& state,
......
...@@ -35,7 +35,9 @@ class SearchState : public common::Shared<_SearchState_> { ...@@ -35,7 +35,9 @@ class SearchState : public common::Shared<_SearchState_> {
public: public:
SearchState() = default; SearchState() = default;
// create a new SearchState // create a new SearchState
explicit SearchState(ir::IRSchedule ir_sch, float cost = NOT_INIT_COST, const std::vector<AutoGenRule*>& rules = {}); explicit SearchState(ir::IRSchedule ir_sch,
float cost = NOT_INIT_COST,
const std::vector<AutoGenRule*>& rules = {});
// Constant standing for a cost not being initialized // Constant standing for a cost not being initialized
static constexpr float NOT_INIT_COST = std::numeric_limits<float>::max(); static constexpr float NOT_INIT_COST = std::numeric_limits<float>::max();
...@@ -62,12 +64,14 @@ struct _SearchState_ : public common::Object { ...@@ -62,12 +64,14 @@ struct _SearchState_ : public common::Object {
static constexpr char* __type_info__ = "auto_schedule_state"; static constexpr char* __type_info__ = "auto_schedule_state";
}; };
// SearchStateHash hash functor that visits every AST node and combine their hash of node_type in dfs order // SearchStateHash hash functor that visits every AST node and combine their
// hash of node_type in dfs order
struct SearchStateHash { struct SearchStateHash {
size_t operator()(const SearchState& s) const; size_t operator()(const SearchState& s) const;
}; };
// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST struct and fields // SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST
// struct and fields
struct SearchStateEqual { struct SearchStateEqual {
bool operator()(const SearchState& lhs, const SearchState& rhs) const; bool operator()(const SearchState& lhs, const SearchState& rhs) const;
}; };
......
...@@ -34,11 +34,14 @@ class MutateRule { ...@@ -34,11 +34,14 @@ class MutateRule {
* @param rand_seed The random seed for mutation. * @param rand_seed The random seed for mutation.
* @return The mutated trace. * @return The mutated trace.
*/ */
virtual ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) = 0; virtual ir::ScheduleDesc Apply(
const ir::ScheduleDesc& trace,
utils::LinearRandomEngine::StateType* rand_seed) = 0;
/** /**
* @brief Create a MutateRule with name. * @brief Create a MutateRule with name.
* @param name The name of mutate rule, consisting of lowercase letters and underscores * @param name The name of mutate rule, consisting of lowercase letters and
* underscores
* @return The created MutateRule. * @return The created MutateRule.
*/ */
static std::unique_ptr<MutateRule> Make(const std::string& name); static std::unique_ptr<MutateRule> Make(const std::string& name);
......
...@@ -20,13 +20,16 @@ namespace cinn { ...@@ -20,13 +20,16 @@ namespace cinn {
namespace auto_schedule { namespace auto_schedule {
/** /**
* The rule to mutate tile size, witch will modify the factors of the Split primitive. * The rule to mutate tile size, witch will modify the factors of the Split
* primitive.
*/ */
class MutateTileSize : public MutateRule { class MutateTileSize : public MutateRule {
public: public:
MutateTileSize() = default; MutateTileSize() = default;
ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) override; ir::ScheduleDesc Apply(
const ir::ScheduleDesc& trace,
utils::LinearRandomEngine::StateType* rand_seed) override;
}; };
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -36,7 +36,8 @@ using ::cinn::hlir::framework::NodeData; ...@@ -36,7 +36,8 @@ using ::cinn::hlir::framework::NodeData;
std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) { std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) {
std::vector<TuneTask> ret_tasks; std::vector<TuneTask> ret_tasks;
const std::vector<std::shared_ptr<Graph::Group>>* groups = &graph->fusion_groups; const std::vector<std::shared_ptr<Graph::Group>>* groups =
&graph->fusion_groups;
std::vector<std::shared_ptr<Graph::Group>> non_fused_groups; std::vector<std::shared_ptr<Graph::Group>> non_fused_groups;
// The input graph doesn't run Op Fusion // The input graph doesn't run Op Fusion
if (graph->fusion_groups.empty()) { if (graph->fusion_groups.empty()) {
...@@ -48,7 +49,7 @@ std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) { ...@@ -48,7 +49,7 @@ std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) {
for (const auto& sub_graph : *groups) { for (const auto& sub_graph : *groups) {
ret_tasks.emplace_back(TuneTask()); ret_tasks.emplace_back(TuneTask());
ret_tasks.back().subgraph = sub_graph; ret_tasks.back().subgraph = sub_graph;
ret_tasks.back().target = graph->target_; ret_tasks.back().target = graph->target_;
} }
return ret_tasks; return ret_tasks;
} }
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册