未验证 提交 af127342 编写于 作者: 张经纬 提交者: GitHub

[CodeStyle][CINN] format cpp code via clang-format (#54961)

* fix clang-format

* 'fix_clang-format'

* fix remaining errors

* format

* empty commit, re-trigger all ci

* empty commit, re-trigger all ci

---------
Co-authored-by: NSigureMo <sigure.qaq@gmail.com>
上级 a7419ff5
......@@ -47,7 +47,8 @@ repos:
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
exclude: |
(?x)^(
paddle/utils/.*
paddle/utils/.*|
paddle/cinn/utils/registry.h
)$
# For Python files
- repo: https://github.com/psf/black.git
......
......@@ -58,26 +58,32 @@ void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
const ir::Load* load_expr = x->As<ir::Load>();
if (load_expr != nullptr) {
const ir::Tensor t = load_expr->tensor.as_tensor_ref();
sche_block->read_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
sche_block->read_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
return false;
}
const ir::Store* store_expr = x->As<ir::Store>();
if (store_expr != nullptr) {
const ir::Tensor t = store_expr->tensor.as_tensor_ref();
sche_block->write_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
sche_block->write_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
return false;
}
return false;
});
}
bool ContainsNodeType(ir::Expr expr, const std::unordered_set<ir::IrNodeTy>& node_types) {
std::set<ir::Expr> collection = ir::CollectIRNodesWithoutTensor(
expr, [&](const Expr* x) { return node_types.find(x->node_type()) != node_types.end(); });
bool ContainsNodeType(ir::Expr expr,
const std::unordered_set<ir::IrNodeTy>& node_types) {
std::set<ir::Expr> collection =
ir::CollectIRNodesWithoutTensor(expr, [&](const Expr* x) {
return node_types.find(x->node_type()) != node_types.end();
});
return !collection.empty();
}
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<ir::LoweredFunc>& lowered_funcs) {
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
const std::vector<ir::LoweredFunc>& lowered_funcs) {
std::unordered_set<std::string> result;
for (const ir::LoweredFunc& func : lowered_funcs) {
for (const ir::Argument& arg : func->args) {
......@@ -90,18 +96,22 @@ std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<
}
bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
const ir::ScheduleBlock* sche_block = sche_block_realize.schedule_block.As<ir::ScheduleBlock>();
if (sche_block->write_buffers.size() != 1 || sche_block->read_buffers.empty()) {
const ir::ScheduleBlock* sche_block =
sche_block_realize.schedule_block.As<ir::ScheduleBlock>();
if (sche_block->write_buffers.size() != 1 ||
sche_block->read_buffers.empty()) {
return false;
}
const ir::Expr& write_buffer = sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer;
const ir::Expr& write_buffer =
sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer;
// Enumerate each read region, get the number of schedule block iter vars
// which are not used to index the read region
int total_unused_iter_vars = 0;
for (const ir::Expr& read_buffer_expr : sche_block->read_buffers) {
const ir::_BufferRange_* read_buffer = read_buffer_expr.As<ir::_BufferRange_>();
const ir::_BufferRange_* read_buffer =
read_buffer_expr.As<ir::_BufferRange_>();
// Skip the reduction buffer
if (read_buffer->buffer == write_buffer) {
continue;
......@@ -133,7 +143,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
return total_unused_iter_vars >= 1;
}
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body) {
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body) {
ir::ModuleExpr mod_expr(std::vector<ir::Expr>({body}));
ir::IRSchedule ir_sch(mod_expr);
......@@ -143,8 +155,10 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo
const std::string& buf_name = buf->name;
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
for (ir::Expr& e : all_block_realizes) {
const ir::ScheduleBlockRealize* sche_block_realize = e.As<ir::ScheduleBlockRealize>();
const std::string& sche_name = sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name;
const ir::ScheduleBlockRealize* sche_block_realize =
e.As<ir::ScheduleBlockRealize>();
const std::string& sche_name =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name;
if (buf_name == "_" + sche_name) {
VLOG(6) << "Set local buffer for temp buffer " << buf_name;
ir_sch.SetBuffer(e, "local", true);
......@@ -159,14 +173,17 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo
#endif
// Get new temp bufs by analyzing.
std::vector<ir::Buffer> new_temp_bufs = lang::GetTempBuffers(old_func->args, updated_body);
ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(old_func->name, old_func->args, updated_body, new_temp_bufs);
std::vector<ir::Buffer> new_temp_bufs =
lang::GetTempBuffers(old_func->args, updated_body);
ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(
old_func->name, old_func->args, updated_body, new_temp_bufs);
#ifdef CINN_WITH_CUDA
if (target == common::DefaultNVGPUTarget()) {
new_func->PrepareCudaAxisInfoFromBody();
}
#endif
new_func = optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref();
new_func =
optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref();
new_func->PrepareBufferCastExprs(/*with_expr_gen_tensor = */ false);
return new_func;
......
......@@ -27,12 +27,14 @@ namespace auto_schedule {
void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block);
bool ContainsNodeType(ir::Expr expr, const std::unordered_set<ir::IrNodeTy>& node_types);
bool ContainsNodeType(ir::Expr expr,
const std::unordered_set<ir::IrNodeTy>& node_types);
/**
* Collects all input lowered_funcs and return names of all output arguments
*/
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<ir::LoweredFunc>& lowered_funcs);
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
const std::vector<ir::LoweredFunc>& lowered_funcs);
/**
* Determine whether a schedule block needs multileveltiling
......@@ -42,7 +44,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize);
/**
* Update a LoweredFunc by regenerating related fields with a new function body
*/
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body);
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body);
} // namespace auto_schedule
} // namespace cinn
......@@ -50,7 +50,8 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ASSERT_FALSE(funcs.empty());
ir::Expr ast_expr = funcs[0]->body;
......@@ -65,8 +66,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
ASSERT_EQ(all_block_realizes.size(), 1UL);
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
ir::ScheduleBlockRealize* sche_block_realize =
all_block_realizes[0].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
AnalyzeScheduleBlockReadWriteBuffer(sche_block);
/*
......@@ -113,7 +116,8 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = poly::CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("AddDiffShape", stages, {C}, {}, {}, nullptr, target, true);
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"AddDiffShape", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: ";
......@@ -126,8 +130,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
ASSERT_EQ(all_block_realizes.size(), 1UL);
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
ir::ScheduleBlockRealize* sche_block_realize =
all_block_realizes[0].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
AnalyzeScheduleBlockReadWriteBuffer(sche_block);
VLOG(6) << "ScheduleBlockRealize: ";
......@@ -164,7 +170,8 @@ TEST(AnalyzeIr, ContainsNodeType) {
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ASSERT_FALSE(funcs.empty());
ir::Expr ast_expr = funcs[0]->body;
......@@ -172,9 +179,12 @@ TEST(AnalyzeIr, ContainsNodeType) {
VLOG(6) << "Analyzing for Expr:";
VLOG(6) << ast_expr;
ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store}));
ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse}));
ASSERT_FALSE(ContainsNodeType(ast_expr, {ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum}));
ASSERT_TRUE(
ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store}));
ASSERT_TRUE(ContainsNodeType(ast_expr,
{ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse}));
ASSERT_FALSE(ContainsNodeType(ast_expr,
{ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum}));
}
} // namespace auto_schedule
......
......@@ -38,13 +38,17 @@
namespace cinn {
namespace auto_schedule {
AutoTuner::AutoTuner(const common::Target& target, hlir::framework::Graph* graph) : target_(target), graph_(graph) {}
AutoTuner::AutoTuner(const common::Target& target,
hlir::framework::Graph* graph)
: target_(target), graph_(graph) {}
void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler) {
void AutoTuner::Initialize(const Config& config,
hlir::framework::GraphCompiler* graph_compiler) {
// create builder, runner, and schedule measurer
builder_ = std::make_unique<SimpleBuilder>(graph_compiler);
runner_ = std::make_unique<SimpleRunner>(config.runner_repeat_times);
schedule_measurer_ = std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get());
schedule_measurer_ =
std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get());
// initialize database
database_ = std::move(Database::Make(config.database_config));
......@@ -53,29 +57,43 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler*
TaskCreator task_creator;
tasks_ = task_creator.CreateTuneTaskOpLevel(graph_);
const auto& dtype_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
const auto& shape_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
const auto& dtype_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph_->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target_);
op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target_);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto i = 0; i < tasks_.size(); ++i) {
auto&& task = tasks_[i];
task.Initialize(shape_dict, dtype_dict, op_lowerer_.get());
// Register the initial ModuleExpr corresponding to the task
task_registry->Regist(task.serialized_key, ir::ModuleExpr(task.GetLoweredFuncBodyExprs()));
VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n" << task.serialized_key;
task_registry->Regist(task.serialized_key,
ir::ModuleExpr(task.GetLoweredFuncBodyExprs()));
VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n"
<< task.serialized_key;
}
// create task optimizers
utils::LinearRandomEngine::StateType initial_seed = utils::LinearRandomEngine::GetDeviceRandomValue();
utils::LinearRandomEngine::StateType initial_seed =
utils::LinearRandomEngine::GetDeviceRandomValue();
task_optimizers_.resize(tasks_.size());
std::transform(tasks_.begin(), tasks_.end(), task_optimizers_.begin(), [&](TuneTask& task) {
std::transform(tasks_.begin(),
tasks_.end(),
task_optimizers_.begin(),
[&](TuneTask& task) {
return std::make_unique<TaskOptimizer>(
&task, schedule_measurer_.get(), database_.get(), utils::ForkRandomState(&initial_seed));
&task,
schedule_measurer_.get(),
database_.get(),
utils::ForkRandomState(&initial_seed));
});
// create task scheduler
task_scheduler_ = TaskScheduler::Make(tasks_, config.task_schedule_config, config.task_schedule_strategy);
task_scheduler_ = TaskScheduler::Make(
tasks_, config.task_schedule_config, config.task_schedule_strategy);
}
void PrintResult(std::shared_ptr<hlir::framework::Graph::Group> group) {
......@@ -127,7 +145,8 @@ void PrintResult(const TuningResult& result) {
TuningResult AutoTuner::Tune(const TuningOptions& options) {
CHECK_GT(options.num_tuning_rounds, 0) << "Invalid config";
VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds << ", tasks size=" << tasks_.size();
VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds
<< ", tasks size=" << tasks_.size();
TuningResult result;
result.subgraphs.resize(tasks_.size());
......
......@@ -49,7 +49,8 @@ class AutoTuner {
AutoTuner(const common::Target& target, hlir::framework::Graph* graph);
// Initialize tuner with specific config and auxiliary objects.
void Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler);
void Initialize(const Config& config,
hlir::framework::GraphCompiler* graph_compiler);
// Perform the tuning process and return the final result
TuningResult Tune(const TuningOptions& options);
......
......@@ -76,11 +76,13 @@ class TestAutoTuner : public ::testing::Test {
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph);
graph_compiler =
std::make_unique<GraphCompiler>(target, compiled_scope, graph);
tuner = std::make_unique<AutoTuner>(target, graph.get());
}
TuningResult InitializeAndTune(const AutoTuner::Config& config, const TuningOptions& options) {
TuningResult InitializeAndTune(const AutoTuner::Config& config,
const TuningOptions& options) {
tuner->Initialize(config, graph_compiler.get());
return tuner->Tune(options);
}
......@@ -108,7 +110,8 @@ class TestAutoTuner : public ::testing::Test {
VLOG(6) << "Print lowered_funcs before building";
VLOG(6) << compile_options.lowered_funcs[0][0];
VLOG(6) << compile_options.lowered_funcs[1][0];
auto runtime_program = graph_compiler->Build(compile_options).runtime_program;
auto runtime_program =
graph_compiler->Build(compile_options).runtime_program;
ASSERT_EQ(1, runtime_program->size());
runtime_program->Execute();
}
......
......@@ -28,7 +28,8 @@
namespace cinn {
namespace auto_schedule {
float ExprCostModel::Predict(const ir::ModuleExpr& sample, const common::Target& target) const {
float ExprCostModel::Predict(const ir::ModuleExpr& sample,
const common::Target& target) const {
if (trained_times_.load() == 0) {
return SearchState::NOT_INIT_COST;
}
......@@ -44,7 +45,8 @@ void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples,
const common::Target& target) {
trained_times_.store(1);
size_t total_size = samples.size();
CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels";
CHECK_EQ(total_size, labels.size())
<< "Samples must have same size as labels";
std::vector<std::vector<float>> train_feature_numbers(total_size);
FeatureExtractor extractor;
for (size_t i = 0; i < total_size; ++i) {
......@@ -61,7 +63,8 @@ void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples,
const common::Target& target) {
++trained_times_;
size_t total_size = samples.size();
CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels";
CHECK_EQ(total_size, labels.size())
<< "Samples must have same size as labels";
std::vector<std::vector<float>> train_feature_numbers(total_size);
FeatureExtractor extractor;
for (size_t i = 0; i < total_size; ++i) {
......
......@@ -29,7 +29,8 @@ namespace auto_schedule {
*/
class ExprCostModel : public XgbCostModel {
public:
virtual float Predict(const ir::ModuleExpr& sample, const common::Target& target) const;
virtual float Predict(const ir::ModuleExpr& sample,
const common::Target& target) const;
void Train(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels,
const common::Target& target);
......
......@@ -49,7 +49,8 @@ Feature::Feature(const common::Target& target)
parent_indices_(1, -1) {}
std::vector<float> Feature::ToFixedSizeVector() {
std::vector<float> ret(LoopBlockFeature::kTotalSize + 1, 0); // LoopBlockFeature::kTotalSize plus 1 for target
std::vector<float> ret(LoopBlockFeature::kTotalSize + 1,
0); // LoopBlockFeature::kTotalSize plus 1 for target
if (target_ == common::DefaultNVGPUTarget()) {
ret[0] = 1;
......@@ -165,11 +166,17 @@ void Feature::IntoLoopBlock() {
current_loop_block_index_ = stack_encoded_feature_.size() - 1;
}
void Feature::ExitLoopBlock() { current_loop_block_index_ = parent_indices_[current_loop_block_index_]; }
void Feature::ExitLoopBlock() {
current_loop_block_index_ = parent_indices_[current_loop_block_index_];
}
LoopBlockFeature& Feature::CurrentLoopBlock() { return stack_encoded_feature_[current_loop_block_index_]; }
LoopBlockFeature& Feature::CurrentLoopBlock() {
return stack_encoded_feature_[current_loop_block_index_];
}
const LoopBlockFeature& Feature::CurrentLoopBlock() const { return stack_encoded_feature_[current_loop_block_index_]; }
const LoopBlockFeature& Feature::CurrentLoopBlock() const {
return stack_encoded_feature_[current_loop_block_index_];
}
} // namespace auto_schedule
} // namespace cinn
......@@ -24,10 +24,18 @@ namespace cinn {
namespace auto_schedule {
/* Loop feature enums */
enum class ForOptimizeFeatureEnum : int { kNone, kGpuBind, kParallel, kUnroll, kVectorize };
enum class ForOptimizeFeatureEnum : int {
kNone,
kGpuBind,
kParallel,
kUnroll,
kVectorize
};
/* function to scale feature numbers */
inline float slog(float x) { return x < 0 ? std::log2(-x + 1) : std::log2(x + 1); }
inline float slog(float x) {
return x < 0 ? std::log2(-x + 1) : std::log2(x + 1);
}
class LoopBlockFeature {
public:
......@@ -106,7 +114,9 @@ class LoopBlockFeature {
static constexpr int kThreadFeatureSize = 8;
static constexpr int kTotalSize = kArithSize + kMemSize + kReduceBroadcastSize + kOptApplySize + kThreadFeatureSize;
static constexpr int kTotalSize = kArithSize + kMemSize +
kReduceBroadcastSize + kOptApplySize +
kThreadFeatureSize;
/* Non-feature attributes, used to maintain during feature_extractor */
......@@ -158,10 +168,11 @@ class Feature {
// some_compute_3
// }
//
// We go through the code and push loops into stack, then the features are encoded as
// [loop_block_feature_0, loop_block_feature_1, loop_block_feature_2, loop_block_feature_3]
// where loop_block_feature_i stores the features of some_compute_i (such
// as number of arithmetic operations)
// We go through the code and push loops into stack, then the features are
// encoded as [loop_block_feature_0, loop_block_feature_1,
// loop_block_feature_2, loop_block_feature_3] where loop_block_feature_i
// stores the features of some_compute_i (such as number of arithmetic
// operations)
//
// loop_block_feature_0.num_sub_loops = 2
// loop_block_feature_1.num_sub_loops = 1
......
......@@ -47,7 +47,8 @@ FeatureExtractor::FeatureExtractor() {}
void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); }
Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr, const common::Target &target) {
Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
const common::Target &target) {
feature_ = Feature(target);
for (const ir::Expr &e : mod_expr.GetExprs()) {
Visit(&e);
......@@ -87,7 +88,8 @@ NotVisitExprFields(_Tensor_)
#define VisitForDtypePattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \
if (x->type() == common::F32() || x->type() == common::F16() || \
x->type() == common::F64()) { \
feature_.CurrentLoopBlock().float_##member += x->type().lanes(); \
} else { \
feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \
......@@ -120,8 +122,10 @@ VisitForDtypePattern(Let, other_call);
#define VisitForMultiOperandsDtypePattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \
feature_.CurrentLoopBlock().float_##member += (x->operands().size() - 1); \
if (x->type() == common::F32() || x->type() == common::F16() || \
x->type() == common::F64()) { \
feature_.CurrentLoopBlock().float_##member += \
(x->operands().size() - 1); \
} else { \
feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \
} \
......@@ -166,7 +170,8 @@ void FeatureExtractor::Visit(const For *x) {
LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock();
if (x->min.is_constant() && x->extent.is_constant()) {
loop_feature.loop_length = (x->extent.get_constant() - x->min.get_constant());
loop_feature.loop_length =
(x->extent.get_constant() - x->min.get_constant());
} else {
loop_feature.loop_length = -1; // -1 represents unknown
}
......@@ -223,13 +228,16 @@ void FeatureExtractor::Visit(const PolyFor *x) {
/* Visit for Reduce and Broadcast */
void FeatureExtractor::Visit(const Reduce *x) {
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) {
if (x->type() == common::F32() || x->type() == common::F16() ||
x->type() == common::F64()) {
switch (x->reduce_type) {
case Reduce::ReduceType::kSum:
feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes();
feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
x->type().lanes();
break;
case Reduce::ReduceType::kSub:
feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes();
feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
x->type().lanes();
break;
case Reduce::ReduceType::kDiv:
feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes();
......@@ -238,10 +246,12 @@ void FeatureExtractor::Visit(const Reduce *x) {
feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes();
break;
case Reduce::ReduceType::kMax:
feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes();
feature_.CurrentLoopBlock().float_reduce_max_or_min +=
x->type().lanes();
break;
case Reduce::ReduceType::kMin:
feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes();
feature_.CurrentLoopBlock().float_reduce_max_or_min +=
x->type().lanes();
break;
}
} else {
......
......@@ -49,7 +49,8 @@ TEST(FeatureExtractor, SimpleAssign) {
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr to test: " << ast_expr;
......@@ -62,7 +63,8 @@ TEST(FeatureExtractor, SimpleAssign) {
std::vector<float> to_check = feature.ToFixedSizeVector();
ASSERT_EQ(to_check.size(), static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
ASSERT_EQ(to_check.size(),
static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
VLOG(6) << "Feature data before slog:";
for (size_t i = 0; i < to_check.size(); ++i) {
VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
......@@ -77,9 +79,11 @@ TEST(FeatureExtractor, SimpleAssign) {
ASSERT_EQ(to_check[0], 0);
#endif
// mem_read
ASSERT_EQ(to_check[17], slog(M.get_constant() * N.get_constant())); // mem_read
ASSERT_EQ(to_check[17],
slog(M.get_constant() * N.get_constant())); // mem_read
// mem_write
ASSERT_EQ(to_check[18], slog(M.get_constant() * N.get_constant())); // mem_write
ASSERT_EQ(to_check[18],
slog(M.get_constant() * N.get_constant())); // mem_write
// non-opt loops, including root block
ASSERT_EQ(to_check[29], slog(3));
}
......@@ -101,10 +105,13 @@ TEST(FeatureExtractor, MatrixMultiply) {
ir::Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = lang::Compute(
{M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
{M, N},
[&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = poly::CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
std::vector<Expr> vec_ast{funcs[0]->body};
ir::ModuleExpr mod_expr(vec_ast);
......@@ -121,7 +128,8 @@ TEST(FeatureExtractor, MatrixMultiply) {
std::vector<float> to_check = feature.ToFixedSizeVector();
ASSERT_EQ(to_check.size(), static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
ASSERT_EQ(to_check.size(),
static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
std::unordered_set<size_t> non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37};
for (size_t i = 0; i < to_check.size(); ++i) {
VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
......
......@@ -57,7 +57,8 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
Dtype* py_data = static_cast<Dtype*>(ret.mutable_data());
for (size_t i = 0; i < vec.size(); ++i) {
assert(vec[i].size() == shape[1] && "Sub vectors must have same size in VectorToNumpy");
assert(vec[i].size() == shape[1] &&
"Sub vectors must have same size in VectorToNumpy");
memcpy(py_data + (shape[1] * i), vec[i].data(), shape[1] * sizeof(Dtype));
}
return ret;
......@@ -71,19 +72,23 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
void AddDistPkgToPythonSysPath() {
pybind11::module sys_py_mod = pybind11::module::import("sys");
// short version such as "3.7", "3.8", ...
std::string py_short_version = sys_py_mod.attr("version").cast<std::string>().substr(0, 3);
std::string py_short_version =
sys_py_mod.attr("version").cast<std::string>().substr(0, 3);
std::string site_pkg_str = "/usr/local/lib/python" + py_short_version + "/dist-packages";
std::string site_pkg_str =
"/usr/local/lib/python" + py_short_version + "/dist-packages";
sys_py_mod.attr("path").attr("append")(site_pkg_str);
// TODO(zhhsplendid): warning to users if setuptools hasn't been installed
DIR* site_pkg_dir = opendir(site_pkg_str.c_str());
if (site_pkg_dir != nullptr) {
std::regex setuptool_regex("setuptools-.*-py" + py_short_version + "\\.egg");
std::regex setuptool_regex("setuptools-.*-py" + py_short_version +
"\\.egg");
struct dirent* entry = nullptr;
while ((entry = readdir(site_pkg_dir)) != nullptr) {
if (std::regex_match(entry->d_name, setuptool_regex)) {
sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" + entry->d_name);
sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" +
entry->d_name);
}
}
closedir(site_pkg_dir);
......@@ -100,36 +105,45 @@ XgbCostModel::XgbCostModel() {
xgb_booster_ = xgb_module_.attr("Booster")();
}
void XgbCostModel::Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) {
void XgbCostModel::Train(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) {
update_samples_ = samples;
update_labels_ = labels;
pybind11::array np_samples = VectorToNumpy<float>(samples);
pybind11::array np_labels = VectorToNumpy<float>(labels);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
xgb_booster_ = xgb_module_.attr("train")(
pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
}
std::vector<float> XgbCostModel::Predict(const std::vector<std::vector<float>>& samples) const {
std::vector<float> XgbCostModel::Predict(
const std::vector<std::vector<float>>& samples) const {
pybind11::array np_samples = VectorToNumpy<float>(samples);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples);
pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix);
return py_result.cast<std::vector<float>>();
}
void XgbCostModel::Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) {
void XgbCostModel::Update(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) {
update_samples_.insert(update_samples_.end(), samples.begin(), samples.end());
update_labels_.insert(update_labels_.end(), labels.begin(), labels.end());
pybind11::array np_samples = VectorToNumpy<float>(update_samples_);
pybind11::array np_labels = VectorToNumpy<float>(update_labels_);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
xgb_booster_ = xgb_module_.attr("train")(
pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
}
void XgbCostModel::Save(const std::string& path) { xgb_booster_.attr("save_model")(pybind11::str(path)); }
void XgbCostModel::Save(const std::string& path) {
xgb_booster_.attr("save_model")(pybind11::str(path));
}
void XgbCostModel::Load(const std::string& path) { xgb_booster_.attr("load_model")(pybind11::str(path)); }
void XgbCostModel::Load(const std::string& path) {
xgb_booster_.attr("load_model")(pybind11::str(path));
}
} // namespace auto_schedule
} // namespace cinn
......@@ -47,11 +47,14 @@ class XgbCostModel : public CostModel {
XgbCostModel();
~XgbCostModel() = default;
void Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override;
void Train(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) override;
std::vector<float> Predict(const std::vector<std::vector<float>>& samples) const override;
std::vector<float> Predict(
const std::vector<std::vector<float>>& samples) const override;
void Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override;
void Update(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) override;
void Save(const std::string& path) override;
......
......@@ -34,7 +34,8 @@ TEST(CostModel, Basic) {
int batch_size = 16;
int feature_size = 8;
std::vector<float> labels(batch_size, 1.0);
std::vector<std::vector<float>> samples(batch_size, std::vector<float>(feature_size));
std::vector<std::vector<float>> samples(batch_size,
std::vector<float>(feature_size));
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < feature_size; ++j) {
samples[i][j] = rand() % 10;
......
......@@ -26,7 +26,8 @@
namespace cinn {
namespace auto_schedule {
bool TuningRecord::Compare::operator()(const TuningRecord& lhs, const TuningRecord& rhs) const {
bool TuningRecord::Compare::operator()(const TuningRecord& lhs,
const TuningRecord& rhs) const {
return lhs.execution_cost < rhs.execution_cost;
}
......@@ -39,15 +40,18 @@ proto::TuningRecord TuningRecord::ToProto() const {
return record_proto;
}
Database::Database(int capacity_per_task) : capacity_per_task_(capacity_per_task) {
CHECK_GT(capacity_per_task_, 0) << "capacity_per_task_ should be greater than 0";
Database::Database(int capacity_per_task)
: capacity_per_task_(capacity_per_task) {
CHECK_GT(capacity_per_task_, 0)
<< "capacity_per_task_ should be greater than 0";
}
std::unique_ptr<Database> Database::Make(const DatabaseConfig& config) {
if (config.type == DatabaseType::kMemory) {
return std::make_unique<Database>(config.capacity_per_task);
} else if (config.type == DatabaseType::kJSONFile) {
return std::make_unique<JSONFileDatabase>(config.capacity_per_task, config.record_file_path, true);
return std::make_unique<JSONFileDatabase>(
config.capacity_per_task, config.record_file_path, true);
}
LOG(FATAL) << "Unimplemented database type.";
......@@ -81,13 +85,16 @@ std::vector<TuningRecord> Database::LookUp(const std::string& task_key) {
return results;
}
std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k) {
std::vector<TuningRecord> Database::GetTopK(const std::string& task_key,
int k) {
auto fit = key2record_.find(task_key);
if (fit == key2record_.end() || k <= 0) {
return {};
}
if (k > capacity_per_task_) {
LOG(WARNING) << "Top k=" << k << " is greater than the capacity, will adjust k=" << capacity_per_task_;
LOG(WARNING) << "Top k=" << k
<< " is greater than the capacity, will adjust k="
<< capacity_per_task_;
k = capacity_per_task_;
}
......@@ -103,8 +110,10 @@ std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k)
}
size_t Database::Size() {
auto res =
std::accumulate(key2record_.begin(), key2record_.end(), size_t(0), [](size_t res, const auto& kv) -> size_t {
auto res = std::accumulate(key2record_.begin(),
key2record_.end(),
size_t(0),
[](size_t res, const auto& kv) -> size_t {
return std::move(res) + kv.second.size();
});
return res;
......
......@@ -39,7 +39,9 @@ struct TuningRecord {
predicted_cost(record.predicted_cost()),
trace(record.trace()),
execution_cost(record.execution_cost()) {}
TuningRecord(const std::string& task_key, const SearchState& state, double execution_cost)
TuningRecord(const std::string& task_key,
const SearchState& state,
double execution_cost)
: task_key(task_key),
predicted_cost(state->predicted_cost),
trace(state->ir_schedule.GetTraceDesc().ToProto()),
......@@ -63,10 +65,10 @@ struct DatabaseConfig {
std::string record_file_path = "/tmp/tuning_record.json";
};
// A database supports insert or lookup historial tuning result with specified traits.
// It can be implemented with a concrete storage to save/load underlying data,
// such as memory, file, database server and so on, this base class can be regarded as
// one using memory as its underlying storage medium.
// A database supports insert or lookup historial tuning result with specified
// traits. It can be implemented with a concrete storage to save/load underlying
// data, such as memory, file, database server and so on, this base class can be
// regarded as one using memory as its underlying storage medium.
class Database {
public:
explicit Database(int capacity_per_task);
......@@ -93,7 +95,9 @@ class Database {
void Insert(const TuningRecord& record);
// map task_key to its records
std::unordered_map<std::string, std::multiset<TuningRecord, TuningRecord::Compare>> key2record_;
std::unordered_map<std::string,
std::multiset<TuningRecord, TuningRecord::Compare>>
key2record_;
// the max number of candidates stored
const int capacity_per_task_;
};
......
......@@ -57,8 +57,10 @@ TEST_F(TestDatabase, GetTopK) {
ASSERT_TRUE(test_db.GetTopK("k5", 2).empty());
ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1);
test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0));
test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0));
test_db.AddRecord(
TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0));
test_db.AddRecord(
TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0));
auto records = test_db.GetTopK("k4", 3);
ASSERT_EQ(records.size(), 2);
......
......@@ -35,7 +35,8 @@ void AppendLineToFile(const std::string& file_path, const std::string& line) {
}
// read lines from a json file
std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool allow_new_file) {
std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
bool allow_new_file) {
std::ifstream is(file_path);
if (is.good()) {
std::vector<std::string> json_strs;
......@@ -51,20 +52,26 @@ std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool al
return {};
}
JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file)
JSONFileDatabase::JSONFileDatabase(int capacity_per_task,
const std::string& record_file_path,
bool allow_new_file)
: Database(capacity_per_task), record_file_path_(record_file_path) {
VLOG(3) << "Auto schedule will save/load tuning records on file:" << record_file_path;
VLOG(3) << "Auto schedule will save/load tuning records on file:"
<< record_file_path;
auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file);
std::vector<cinn::auto_schedule::proto::TuningRecord> all_records_proto(json_lines.size());
std::vector<cinn::auto_schedule::proto::TuningRecord> all_records_proto(
json_lines.size());
// convert JSON string to proto object
auto worker_fn = [this, &json_lines, &all_records_proto](int index) {
cinn::auto_schedule::proto::TuningRecord record_proto;
auto status = google::protobuf::util::JsonStringToMessage(json_lines[index], &record_proto);
auto status = google::protobuf::util::JsonStringToMessage(json_lines[index],
&record_proto);
CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index];
all_records_proto[index].Swap(&record_proto);
};
utils::parallel_run(worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1);
utils::parallel_run(
worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
......@@ -81,8 +88,10 @@ JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& rec
std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) {
proto::TuningRecord record_proto = record.ToProto();
std::string json_string;
auto status = google::protobuf::util::MessageToJsonString(record_proto, &json_string);
CHECK(status.ok()) << "Failed to serialize record to JSON, task key = " << record.task_key;
auto status =
google::protobuf::util::MessageToJsonString(record_proto, &json_string);
CHECK(status.ok()) << "Failed to serialize record to JSON, task key = "
<< record.task_key;
VLOG(4) << "json_string = \n" << json_string;
return json_string;
......
......@@ -19,16 +19,20 @@
namespace cinn {
namespace auto_schedule {
// JSONFileDatabase is a database implemented by JSON file to save/load underlying data.
// JSONFileDatabase is a database implemented by JSON file to save/load
// underlying data.
class JSONFileDatabase : public Database {
public:
/*!
* \brief Build a JSONFileDatabase object from a json file.
* \param capacity_per_task The max number of candidates stored.
* \param record_file_path The path of the json file.
* \param allow_new_file Whether to create new file when the given path is not found.
* \param allow_new_file Whether to create new file when the given path is not
* found.
*/
JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file);
JSONFileDatabase(int capacity_per_task,
const std::string& record_file_path,
bool allow_new_file);
~JSONFileDatabase() = default;
// convert a TuningRecord object to string in JSON format
......@@ -46,7 +50,8 @@ class JSONFileDatabase : public Database {
void AppendLineToFile(const std::string& file_path, const std::string& line);
// read lines from a json file
std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool allow_new_file = true);
std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
bool allow_new_file = true);
} // namespace auto_schedule
} // namespace cinn
......@@ -31,7 +31,8 @@ namespace cinn {
namespace auto_schedule {
// Return lowerd ir AST for example functions used in this test
std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const Target& target) {
std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape,
const Target& target) {
CHECK(shape.size() == 2) << "shape should be 2";
std::vector<Expr> domain;
for (auto i = 0; i < shape.size(); ++i) {
......@@ -46,11 +47,13 @@ std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const T
C = Compute(
domain, [&B](Var i, Var j) { return B(i, j); }, "C");
return cinn::lang::LowerVec("test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
return cinn::lang::LowerVec(
"test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
}
// Create a new IRSchedule with copied ir::LoweredFunc AST
ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs, const std::string& task_key) {
ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
const std::string& task_key) {
std::vector<Expr> exprs;
for (auto&& func : lowered_funcs) {
exprs.emplace_back(optim::IRCopy(func->body));
......@@ -63,7 +66,9 @@ ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
class TestJSONFileDatabase : public ::testing::Test {
public:
TestJSONFileDatabase() : record_file_path("/tmp/test_record.json"), test_db(2, record_file_path, true) {}
TestJSONFileDatabase()
: record_file_path("/tmp/test_record.json"),
test_db(2, record_file_path, true) {}
void SetUp() override { lowered_funcs = LowerCompute({32, 32}, target); }
......@@ -97,14 +102,19 @@ TEST_F(TestJSONFileDatabase, Serialize) {
TuningRecord record1("test", SearchState(std::move(ir_sch), 2.0), 1.0);
std::string str = test_db.RecordToJSON(record1);
VLOG(3) << "RecordToJSON: " << str;
// Because the serialization of protobuf does not guarantee the order, we give all possible results.
// Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
std::string case1 =
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":"
"{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":"
"\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}";
std::string case2 =
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":"
"{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":"
"\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}";
EXPECT_EQ(true, str == case1 || str == case2);
}
......@@ -114,32 +124,48 @@ TEST_F(TestJSONFileDatabase, SaveLoad) {
auto fused1 = ir_sch1.Fuse("B", {0, 1});
ir::IRSchedule ir_sch2 = MakeIRSchedule(lowered_funcs, "k2");
test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch1), 1.5), 1.0));
test_db.AddRecord(TuningRecord("k2", SearchState(std::move(ir_sch2), 3.5), 3.0));
test_db.AddRecord(
TuningRecord("k1", SearchState(std::move(ir_sch1), 1.5), 1.0));
test_db.AddRecord(
TuningRecord("k2", SearchState(std::move(ir_sch2), 3.5), 3.0));
std::vector<std::string> strs = ReadLinesFromFile(record_file_path);
ASSERT_EQ(strs.size(), 2);
// Because the serialization of protobuf does not guarantee the order, we give all possible results.
// Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
std::string case1 =
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":"
"{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":"
"\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}";
std::string case2 =
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":"
"{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":"
"\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}";
EXPECT_EQ(true, strs[0] == case1 || strs[0] == case2);
EXPECT_EQ(strs[1], "{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,\"trace\":{}}");
EXPECT_EQ(strs[1],
"{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,"
"\"trace\":{}}");
}
TEST_F(TestJSONFileDatabase, Basic) {
test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0));
test_db.AddRecord(TuningRecord(
"k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0));
test_db.AddRecord(TuningRecord(
"k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(TuningRecord(
"k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0));
ASSERT_EQ(test_db.Size(), 6);
auto records = test_db.LookUp("k3");
......@@ -152,15 +178,24 @@ TEST_F(TestJSONFileDatabase, Basic) {
}
TEST_F(TestJSONFileDatabase, GetTopK) {
test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0));
test_db.AddRecord(TuningRecord(
"k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0));
test_db.AddRecord(TuningRecord(
"k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(TuningRecord(
"k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0));
auto records = test_db.GetTopK("k4", 3);
ASSERT_EQ(records.size(), 2);
......@@ -171,8 +206,10 @@ TEST_F(TestJSONFileDatabase, GetTopK) {
TEST_F(TestJSONFileDatabase, Reload) {
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1");
auto fused = ir_sch.Fuse("B", {0, 1});
test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(
TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0));
test_db.AddRecord(TuningRecord(
"k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
auto records = test_db.LookUp("k1");
ASSERT_EQ(records.size(), 1);
......@@ -184,11 +221,13 @@ TEST_F(TestJSONFileDatabase, Reload) {
EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost);
EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost);
// check the equality of trace info between original TuningRecord and the loaded TuningRecord
// check the equality of trace info between original TuningRecord and the
// loaded TuningRecord
const auto& lhs_trace = records[0].trace;
const auto& rhs_trace = loaded_records[0].trace;
google::protobuf::util::MessageDifferencer dif;
static const google::protobuf::Descriptor* descriptor = cinn::ir::proto::ScheduleDesc_Step::descriptor();
static const google::protobuf::Descriptor* descriptor =
cinn::ir::proto::ScheduleDesc_Step::descriptor();
dif.TreatAsSet(descriptor->FindFieldByName("attrs"));
EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace));
......
......@@ -53,7 +53,8 @@ struct MeasureResult {
// The result of building with input schedule
struct BuildResult {
// The scope that owns detail compilation infos of parameters in the runtime program
// The scope that owns detail compilation infos of parameters in the runtime
// program
const hlir::framework::Scope* compiled_scope;
// The executable program
std::unique_ptr<hlir::framework::Program> runtime_program;
......@@ -68,11 +69,13 @@ class ScheduleBuilder {
virtual BuildResult Build(const MeasureInput& input) = 0;
};
// This interface defines how to run the built result. Like above ScheduleBuilder,
// a runner shoule be implemented with not bound to a specific task.
// This interface defines how to run the built result. Like above
// ScheduleBuilder, a runner shoule be implemented with not bound to a specific
// task.
class ScheduleRunner {
public:
virtual MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) = 0;
virtual MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) = 0;
};
} // namespace auto_schedule
......
......@@ -68,10 +68,15 @@ class TestMeasurer : public ::testing::Test {
graph_compiler = std::make_unique<GraphCompiler>(target, scope, graph);
TaskCreator task_creator;
tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target);
const auto& dtype_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
"infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
inputs.reserve(tasks.size());
for (int i = 0; i < tasks.size(); ++i) {
auto* task = &tasks[i];
......@@ -95,13 +100,17 @@ class ThrowExceptionRunner : public ScheduleRunner {
struct Exception : public std::exception {
const char* what() const throw() { return "RunError"; }
};
MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override { throw Exception(); }
MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) override {
throw Exception();
}
};
TEST_F(TestMeasurer, Basic) {
auto builder = std::make_unique<SimpleBuilder>(graph_compiler.get());
auto runner = std::make_unique<SimpleRunner>(1);
auto measurer = std::make_unique<ScheduleMeasurer>(builder.get(), runner.get());
auto measurer =
std::make_unique<ScheduleMeasurer>(builder.get(), runner.get());
std::vector<MeasureResult> results = measurer->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size());
}
......@@ -111,13 +120,16 @@ TEST_F(TestMeasurer, CatchException) {
auto runner = std::make_unique<SimpleRunner>(1);
auto throw_builder = std::make_unique<ThrowExceptionBuilder>();
auto throw_runner = std::make_unique<ThrowExceptionRunner>();
auto measurer_with_build_error = std::make_unique<ScheduleMeasurer>(throw_builder.get(), runner.get(), 2);
std::vector<MeasureResult> results = measurer_with_build_error->Measure(inputs);
auto measurer_with_build_error =
std::make_unique<ScheduleMeasurer>(throw_builder.get(), runner.get(), 2);
std::vector<MeasureResult> results =
measurer_with_build_error->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size());
EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n");
// TODO(CtfGo): test parallel build after we support thread-safe compilation
auto measurer_with_run_error = std::make_unique<ScheduleMeasurer>(builder.get(), throw_runner.get(), 1);
auto measurer_with_run_error =
std::make_unique<ScheduleMeasurer>(builder.get(), throw_runner.get(), 1);
results = measurer_with_run_error->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size());
EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n");
......
......@@ -21,10 +21,13 @@
namespace cinn {
namespace auto_schedule {
ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads)
ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder,
ScheduleRunner* runner,
int num_threads)
: builder_(builder), runner_(runner), num_threads_(num_threads) {}
std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureInput>& inputs) {
std::vector<MeasureResult> ScheduleMeasurer::Measure(
const std::vector<MeasureInput>& inputs) {
if (inputs.empty()) {
LOG(WARNING) << "inputs is empty";
return {};
......@@ -33,20 +36,24 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn
std::vector<MeasureResult> results(inputs.size());
// define how to build a candidate with the specified index
auto build_fn = [builder = builder_, &inputs, &build_results, &results](int index) {
auto build_fn =
[builder = builder_, &inputs, &build_results, &results](int index) {
VLOG(6) << "Build candidate index: " << index;
auto m_start = std::chrono::steady_clock::now();
try {
build_results[index] = builder->Build(inputs[index]);
} catch (std::exception& e) {
results[index].error_msg = utils::StringFormat("Build failed, error: %s\n", e.what());
results[index].error_msg =
utils::StringFormat("Build failed, error: %s\n", e.what());
}
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - m_start);
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - m_start);
results[index].elapsed_time += static_cast<double>(time_span.count());
};
// define how to run a candidate with the specified index
auto run_fn = [runner = runner_, &inputs, &build_results, &results](int index) {
auto run_fn =
[runner = runner_, &inputs, &build_results, &results](int index) {
VLOG(6) << "Run candidate index: " << index;
auto m_start = std::chrono::steady_clock::now();
try {
......@@ -55,9 +62,11 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn
results[index] = runner->Run(inputs[index], build_results[index]);
}
} catch (std::exception& e) {
results[index].error_msg = utils::StringFormat("Run failed, error: %s\n", e.what());
results[index].error_msg =
utils::StringFormat("Run failed, error: %s\n", e.what());
}
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - m_start);
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - m_start);
results[index].elapsed_time += static_cast<double>(time_span.count());
};
......@@ -66,8 +75,10 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn
build_fn(index);
run_fn(index);
};
// default num_threads_ is 1 and in that case it will perform all measurements sequentially inplace.
utils::parallel_run(measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_);
// default num_threads_ is 1 and in that case it will perform all measurements
// sequentially inplace.
utils::parallel_run(
measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_);
VLOG(4) << "Measure " << inputs.size() << " candidates";
return results;
......
......@@ -25,7 +25,9 @@ namespace auto_schedule {
// which are building the input schedules and running the generated codes.
class ScheduleMeasurer {
public:
ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads = 1);
ScheduleMeasurer(ScheduleBuilder* builder,
ScheduleRunner* runner,
int num_threads = 1);
// Measure a batch of inputs and return all results once.
std::vector<MeasureResult> Measure(const std::vector<MeasureInput>& inputs);
......
......@@ -19,17 +19,21 @@ namespace auto_schedule {
using hlir::framework::GraphCompiler;
SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler) : graph_compiler_(graph_compiler) {}
SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler)
: graph_compiler_(graph_compiler) {}
BuildResult SimpleBuilder::Build(const MeasureInput& input) {
CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr)) << "empty handle to GraphCompiler";
CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr))
<< "empty handle to GraphCompiler";
GraphCompiler::CompileOptions compile_options;
compile_options.groups.emplace_back(input.task->subgraph);
compile_options.lowered_funcs.emplace_back(input.lowered_funcs);
compile_options.remove_unused_variables = false;
VLOG(5) << "call GraphCompiler to Build with Graph::Group size=" << compile_options.groups.size()
<< ", lowered_funcs group size=" << compile_options.lowered_funcs.size();
GraphCompiler::CompilationResult compiled_result = graph_compiler_->Build(compile_options);
VLOG(5) << "call GraphCompiler to Build with Graph::Group size="
<< compile_options.groups.size() << ", lowered_funcs group size="
<< compile_options.lowered_funcs.size();
GraphCompiler::CompilationResult compiled_result =
graph_compiler_->Build(compile_options);
BuildResult build_result;
build_result.compiled_scope = graph_compiler_->GetScope().get();
......
......@@ -35,7 +35,8 @@ using hlir::framework::Tensor;
// Parameters that needs to be initialized to 0.
// Key is the Op name, and value is the index of the input parameter in the Op.
static const std::unordered_map<std::string, std::vector<int>> kInitWithZeroParams = {
static const std::unordered_map<std::string, std::vector<int>>
kInitWithZeroParams = {
{"lookup_table", {1}},
{"gather", {1}},
{"gather_nd", {1}},
......@@ -44,38 +45,53 @@ static const std::unordered_map<std::string, std::vector<int>> kInitWithZeroPara
};
// Generate random value and populate them to the output address of memory
static void PopulateRandomValue(const common::Type& type, const int numel, void* raw_ptr) {
static void PopulateRandomValue(const common::Type& type,
const int numel,
void* raw_ptr) {
std::random_device seed;
std::default_random_engine engine(seed());
if (type == common::Bool()) {
auto* fmt_ptr = reinterpret_cast<bool*>(raw_ptr);
std::bernoulli_distribution dist(0.5);
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::I32()) {
auto* fmt_ptr = reinterpret_cast<int*>(raw_ptr);
std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(), std::numeric_limits<int>::max());
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(),
std::numeric_limits<int>::max());
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::I64()) {
auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr);
std::uniform_int_distribution<int64_t> dist(std::numeric_limits<int64_t>::min(),
std::uniform_int_distribution<int64_t> dist(
std::numeric_limits<int64_t>::min(),
std::numeric_limits<int64_t>::max());
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::F32()) {
auto* fmt_ptr = reinterpret_cast<float*>(raw_ptr);
std::uniform_real_distribution<float> dist(std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
std::uniform_real_distribution<float> dist(
std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else {
CHECK_EQ(type.bytes(), 8) << "Unsupported type: " << type << ", type.bytes = " << type.bytes();
CHECK_EQ(type.bytes(), 8)
<< "Unsupported type: " << type << ", type.bytes = " << type.bytes();
auto* fmt_ptr = reinterpret_cast<uint8_t*>(raw_ptr);
std::uniform_int_distribution<uint8_t> dist(std::numeric_limits<uint8_t>::min(),
std::uniform_int_distribution<uint8_t> dist(
std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
}
}
// Initialize a tensor with 0 if init_with_zero == true, otherwise initialize the tensor with random value.
static void InitTensorData(Tensor tensor, const common::Target& target, bool init_with_zero) {
// Initialize a tensor with 0 if init_with_zero == true, otherwise initialize
// the tensor with random value.
static void InitTensorData(Tensor tensor,
const common::Target& target,
bool init_with_zero) {
int mem_size = tensor->shape().numel() * tensor->type().bytes();
auto* tensor_data = tensor->mutable_data(target, tensor->type());
#ifdef CINN_WITH_CUDA
......@@ -101,9 +117,11 @@ static void InitTensorData(Tensor tensor, const common::Target& target, bool ini
// Find all parameter names in the task corresponding to the MeasureInput
// that need to be initialized to 0 when measuring.
static std::unordered_set<std::string> ParamsNeedInitWithZero(const MeasureInput& input) {
static std::unordered_set<std::string> ParamsNeedInitWithZero(
const MeasureInput& input) {
std::unordered_set<std::string> res;
std::vector<hlir::framework::Node*> nodes = input.task->subgraph->CollectNodes();
std::vector<hlir::framework::Node*> nodes =
input.task->subgraph->CollectNodes();
for (auto* node : nodes) {
if (kInitWithZeroParams.count(node->op()->name) != 0) {
std::vector<int> param_idxs = kInitWithZeroParams.at(node->op()->name);
......@@ -111,7 +129,8 @@ static std::unordered_set<std::string> ParamsNeedInitWithZero(const MeasureInput
for (int param_idx : param_idxs) {
CHECK_GT(inlinks.size(), param_idx);
auto& edge = inlinks.at(param_idx);
std::string param_name = edge->source()->as<hlir::framework::NodeData>()->id();
std::string param_name =
edge->source()->as<hlir::framework::NodeData>()->id();
VLOG(6) << "param needs to be init with 0: " << param_name;
res.insert(param_name);
}
......@@ -128,7 +147,8 @@ SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) {
// Prepare execution arguments of all instructions to run, a argument
// may be obtained from the input of measurement or allocating new buffer
// with random value.
std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureInput& input,
std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(
const MeasureInput& input,
const BuildResult& build_result,
hlir::framework::Scope* temp_scope) {
std::map<std::string, cinn_pod_value_t> result;
......@@ -138,7 +158,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI
const auto* compiled_scope = build_result.compiled_scope;
const auto& instructions = build_result.runtime_program->GetRunInstructions();
std::unordered_set<std::string> params_need_init_with_zero = ParamsNeedInitWithZero(input);
std::unordered_set<std::string> params_need_init_with_zero =
ParamsNeedInitWithZero(input);
auto fill_arg_fn = [&](const std::string& param) {
VLOG(6) << "Filling argument:" << param;
......@@ -169,7 +190,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI
temp_tensor->Resize(compiled_tensor->shape());
temp_tensor->set_type(compiled_tensor->type());
temp_tensor->mutable_data(target, compiled_tensor->type());
InitTensorData(temp_tensor, target, params_need_init_with_zero.count(param) != 0);
InitTensorData(
temp_tensor, target, params_need_init_with_zero.count(param) != 0);
result.emplace(param, temp_tensor->buffer());
};
......@@ -186,7 +208,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI
return result;
}
MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& build_result) {
MeasureResult SimpleRunner::Run(const MeasureInput& input,
const BuildResult& build_result) {
MeasureResult result;
auto t_start = std::chrono::steady_clock::now();
// prepare execution arguments
......@@ -209,16 +232,18 @@ MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& bu
CUDA_CALL(cudaDeviceSynchronize());
}
#endif
auto time_span =
std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - run_start);
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - run_start);
auto cost_avg = static_cast<double>(time_span.count()) / repeat_times_;
result.execution_cost += cost_avg;
}
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - t_start);
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - t_start);
result.elapsed_time = static_cast<double>(time_span.count());
VLOG(4) << "A measurement done:repeat_times[" << repeat_times_ << "]total_elapsed_time[" << result.elapsed_time
VLOG(4) << "A measurement done:repeat_times[" << repeat_times_
<< "]total_elapsed_time[" << result.elapsed_time
<< "]us,execution_cost[" << result.execution_cost << "]us";
return result;
}
......
......@@ -26,10 +26,12 @@ class SimpleRunner : public ScheduleRunner {
public:
SimpleRunner(int repeat_times);
MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override;
MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) override;
private:
std::map<std::string, cinn_pod_value_t> PrepareArgs(const MeasureInput& input,
std::map<std::string, cinn_pod_value_t> PrepareArgs(
const MeasureInput& input,
const BuildResult& build_result,
hlir::framework::Scope* temp_scope);
......
......@@ -56,7 +56,8 @@ class TestSimpleRunner : public ::testing::Test {
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph);
graph_compiler =
std::make_unique<GraphCompiler>(target, compiled_scope, graph);
auto runtime_program = graph_compiler->Build();
const auto& instructions = runtime_program->GetRunInstructions();
ASSERT_EQ(1, instructions.size());
......@@ -115,11 +116,15 @@ TEST_F(TestSimpleRunner, TimeMeasured) {
BuildResult build_result;
build_result.compiled_scope = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions;
instructions.emplace_back(
new Instruction(common::DefaultHostTarget(), nullptr, {}, {"empty_placeholder"}, "sleep_fn"));
instructions.emplace_back(new Instruction(common::DefaultHostTarget(),
nullptr,
{},
{"empty_placeholder"},
"sleep_fn"));
instructions.back()->SetLoweredFunc(reinterpret_cast<void*>(sleep_fn));
instructions.back()->Finalize();
build_result.runtime_program.reset(new hlir::framework::Program(nullptr, std::move(instructions)));
build_result.runtime_program.reset(
new hlir::framework::Program(nullptr, std::move(instructions)));
// to skip the condition check of params in Instruction::PreparePodArgs
std::map<std::string, cinn_pod_value_t> preset_args;
......
......@@ -22,10 +22,12 @@
namespace cinn {
namespace auto_schedule {
int ExtractNumThreads(const ir::IRSchedule& ir_schedule, const std::string& bind_axis) {
int ExtractNumThreads(const ir::IRSchedule& ir_schedule,
const std::string& bind_axis) {
const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc();
for (auto&& step : trace.Steps()) {
if (step.type == "Bind" && step.attrs.find("thread_axis") != step.attrs.end() &&
if (step.type == "Bind" &&
step.attrs.find("thread_axis") != step.attrs.end() &&
absl::get<std::string>(step.attrs.at("thread_axis")) == bind_axis) {
CHECK_EQ(step.inputs.at("loop").size(), 1);
return step.inputs.at("loop")[0].As<ir::For>()->extent.as_int32();
......@@ -38,9 +40,13 @@ std::vector<std::string> FindCandidates(const ir::ScheduleDesc& trace) {
std::vector<std::string> candidate_block_names;
for (auto&& step : trace.Steps()) {
if (step.type == "AnnotateIntAttr" &&
absl::get<std::string>(step.attrs.at("key")) == ir::attr::cooperative_process) {
absl::get<std::string>(step.attrs.at("key")) ==
ir::attr::cooperative_process) {
candidate_block_names.push_back(
step.inputs.at("block")[0].As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name);
step.inputs.at("block")[0]
.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name);
}
}
return candidate_block_names;
......
......@@ -20,8 +20,9 @@ namespace cinn {
namespace auto_schedule {
/*
* @brief Rewrite the cooperative_process annotation to actually bind the loop on threadIdx.
* This rule is used for collaborative data handling of multiple threads within the same block.
* @brief Rewrite the cooperative_process annotation to actually bind the loop
* on threadIdx. This rule is used for collaborative data handling of multiple
* threads within the same block.
*/
class CooperativeProcess : public PostScheduleRule {
public:
......
......@@ -44,17 +44,27 @@ TEST_F(TestCooperativeProcess, Matmul) {
int steps_k = 8;
Initialize(common::DefaultNVGPUTarget());
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);
// split loops
std::vector<ir::Expr> loops = ir_schedule.GetLoops("temp_matmul_out");
std::vector<ir::Expr> k_loops = ir_schedule.Split(loops[2], {steps_k, -1});
std::vector<ir::Expr> j_loops = ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1});
std::vector<ir::Expr> i_loops = ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1});
std::vector<ir::Expr> j_loops =
ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1});
std::vector<ir::Expr> i_loops =
ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1});
// reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2
loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.Reorder({loops[0], loops[3], loops[1], loops[4], loops[6], loops[7], loops[2], loops[5]});
ir_schedule.Reorder({loops[0],
loops[3],
loops[1],
loops[4],
loops[6],
loops[7],
loops[2],
loops[5]});
// fuse and bind
loops = ir_schedule.GetLoops("temp_matmul_out");
ir::Expr i1_j1_fused = ir_schedule.Fuse({loops[2], loops[3]});
......@@ -65,23 +75,31 @@ TEST_F(TestCooperativeProcess, Matmul) {
// cache read
ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out");
ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared");
std::string X_cache_block_name =
X_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
std::string X_cache_block_name = X_cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.ComputeAt(X_cache_block, loops[2]);
std::vector<ir::Expr> X_cache_loops = ir_schedule.GetLoops(X_cache_block_name);
std::vector<ir::Expr> X_cache_loops =
ir_schedule.GetLoops(X_cache_block_name);
ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]});
ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name), ir::attr::cooperative_process, 0);
ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name),
ir::attr::cooperative_process,
0);
out_block = ir_schedule.GetBlock("temp_matmul_out");
ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared");
std::string Y_cache_block_name =
Y_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
std::string Y_cache_block_name = Y_cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.ComputeAt(Y_cache_block, loops[2]);
std::vector<ir::Expr> Y_cache_loops = ir_schedule.GetLoops(Y_cache_block_name);
std::vector<ir::Expr> Y_cache_loops =
ir_schedule.GetLoops(Y_cache_block_name);
ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]});
ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name), ir::attr::cooperative_process, 0);
ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name),
ir::attr::cooperative_process,
0);
// apply CooperativeProcess
CooperativeProcess cooperative_process;
......@@ -187,7 +205,8 @@ TEST_F(TestCooperativeProcess, Matmul) {
// execute and check precision
CheckResult(
GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(
matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
{X_shape, Y_shape},
......
......@@ -29,21 +29,28 @@ static constexpr uint32_t kMaxBlocks = 256;
bool IsSpatialLoop(const ir::For* for_node) {
if (for_node->for_type() != ir::ForType::Serial) return false;
const auto& loop_var = for_node->loop_var;
// collect cases where the loop_var used in one of reduce axis in underneath ScheduleBlock
auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(for_node->body, [&loop_var](const Expr* x) {
// collect cases where the loop_var used in one of reduce axis in underneath
// ScheduleBlock
auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(
for_node->body, [&loop_var](const Expr* x) {
const auto* block_realize = x->As<ir::ScheduleBlockRealize>();
if (!block_realize) return false;
const auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();
const auto* schedule_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
CHECK_EQ(block_realize->iter_values.size(), schedule_block->iter_vars.size());
CHECK_EQ(block_realize->iter_values.size(),
schedule_block->iter_vars.size());
for (int i = 0; i < block_realize->iter_values.size(); ++i) {
const ir::Var& iter_var = schedule_block->iter_vars[i];
const ir::Expr& binding = block_realize->iter_values[i];
if (iter_var->is_reduce_axis || iter_var->name.substr(0, 6) == "reduce") {
auto used_exprs = ir::CollectIRNodesWithoutTensor(binding, [&loop_var](const Expr* x) {
if (iter_var->is_reduce_axis ||
iter_var->name.substr(0, 6) == "reduce") {
auto used_exprs = ir::CollectIRNodesWithoutTensor(
binding, [&loop_var](const Expr* x) {
const ir::_Var_* var = x->As<ir::_Var_>();
if (var && (x->same_as(loop_var) || var->name == loop_var->name)) {
if (var &&
(x->same_as(loop_var) || var->name == loop_var->name)) {
return true;
}
return false;
......@@ -59,7 +66,8 @@ bool IsSpatialLoop(const ir::For* for_node) {
return true;
}
// count the number of loops that can be binded from the input for_node to bottom
// count the number of loops that can be binded from the input for_node to
// bottom
int CountLoopCanBinded(const ir::For* for_node) {
int cnt = 0;
while (for_node) {
......@@ -68,9 +76,11 @@ int CountLoopCanBinded(const ir::For* for_node) {
cnt += 1;
CHECK(for_node->body.defined() && for_node->body.As<ir::Block>()) << "Body is not defined";
CHECK(for_node->body.defined() && for_node->body.As<ir::Block>())
<< "Body is not defined";
const ir::Block* body = for_node->body.As<ir::Block>();
// terminate when body of this loop has more than one statement or the body is not a ir::For node
// terminate when body of this loop has more than one statement or the body
// is not a ir::For node
for_node = body->stmts.size() == 1 ? body->stmts[0].As<ir::For>() : nullptr;
}
return cnt;
......@@ -82,13 +92,17 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule,
int max_blocks,
int max_threads_per_block) {
auto all_loops = ir_schedule->GetLoops(block_name);
CHECK_LE(num_loops_to_bind, all_loops.size()) << "The number of loops to be bind is greater than size of all_loops";
// check whether it is the case that threadIdx has been binded but blockIdx not,
// the threadIdx can only be binded in the first loop after num_loops_to_bind loops
// because we has excluded other cases in CountLoopCanBinded
CHECK_LE(num_loops_to_bind, all_loops.size())
<< "The number of loops to be bind is greater than size of all_loops";
// check whether it is the case that threadIdx has been binded but blockIdx
// not, the threadIdx can only be binded in the first loop after
// num_loops_to_bind loops because we has excluded other cases in
// CountLoopCanBinded
bool gpu_thread_has_binded =
num_loops_to_bind < all_loops.size() && all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded();
Expr fused_loop = ir_schedule->Fuse({all_loops.begin(), all_loops.begin() + num_loops_to_bind});
num_loops_to_bind < all_loops.size() &&
all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded();
Expr fused_loop = ir_schedule->Fuse(
{all_loops.begin(), all_loops.begin() + num_loops_to_bind});
int32_t extent = fused_loop.As<ir::For>()->extent.as_int32();
if (gpu_thread_has_binded) {
ir_schedule->Bind(fused_loop, "blockIdx.x");
......@@ -106,7 +120,8 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule,
ir_schedule->Bind(splits[0], "blockIdx.x");
ir_schedule->Bind(splits[1], "threadIdx.x");
} else {
auto splits = ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block});
auto splits =
ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block});
CHECK_EQ(splits.size(), 3);
ir_schedule->Reorder({splits[1], splits[2], splits[0]});
all_loops = ir_schedule->GetLoops(block_name);
......@@ -126,29 +141,36 @@ RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) {
}
num_applicable_ = applicable_schedule_blocks_.size();
VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
void AutoBind::Apply(int index) {
CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index;
CHECK_LT(index, applicable_schedule_blocks_.size())
<< "invalid apply index:" << index;
auto applied_block = applicable_schedule_blocks_.at(index);
auto all_loops = ir_schedule_->GetLoops(applied_block);
BindGPUIndex(ir_schedule_,
applied_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name,
applied_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name,
CountLoopCanBinded(all_loops[0].As<ir::For>()),
kMaxBlocks,
target_->max_num_threads());
return;
}
RuleApplyType AutoBind::AnalyseApplyType(SearchState state, const std::string& block_name) const {
RuleApplyType AutoBind::AnalyseApplyType(SearchState state,
const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto all_loops = state->ir_schedule.GetLoops(block_expr);
return CountLoopCanBinded(all_loops[0].As<ir::For>()) > 0 ? RuleApplyType::kApplyAndPruneOtherRules
return CountLoopCanBinded(all_loops[0].As<ir::For>()) > 0
? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state, const std::string& block_name) {
std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state,
const std::string& block_name) {
SearchState new_state = state.Copy();
auto all_loops = state->ir_schedule.GetLoops(block_name);
BindGPUIndex(&new_state->ir_schedule,
......
......@@ -36,9 +36,11 @@ class AutoBind : public AutoGenRule {
std::string GetRuleName() const override { return "AutoBind"; }
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override;
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override;
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private:
std::vector<Expr> applicable_schedule_blocks_;
......
......@@ -36,9 +36,11 @@ class TestAutoBind : public TestAutoGenRuleBase {
std::vector<std::string> default_input_names = {"X", "Y"};
std::vector<std::string> default_output_names = {"temp_matmul_out"};
void TestApplyOnElementWiseAdd(const std::vector<int>& shape, const std::string& block_name) {
void TestApplyOnElementWiseAdd(const std::vector<int>& shape,
const std::string& block_name) {
Initialize(common::DefaultNVGPUTarget());
auto test_program = tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}});
auto test_program =
tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}});
// construct input parameter
ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
SearchState state(ir_schedule, 0, {});
......@@ -48,7 +50,8 @@ class TestAutoBind : public TestAutoGenRuleBase {
// apply
AutoBind auto_bind(target_);
ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name), RuleApplyType::kApplyAndPruneOtherRules);
ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name),
RuleApplyType::kApplyAndPruneOtherRules);
auto result = auto_bind.ApplyOnBlock(state, block_name)[0];
std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
......@@ -56,7 +59,8 @@ class TestAutoBind : public TestAutoGenRuleBase {
// check bind result
auto all_loops = result->ir_schedule.GetLoops(block_name);
int total_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int total_num =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
if (total_num <= kMaxThreadsPerBlock) {
ASSERT_EQ(all_loops.size(), 1);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), total_num);
......@@ -64,18 +68,22 @@ class TestAutoBind : public TestAutoGenRuleBase {
} else if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) {
ASSERT_EQ(all_loops.size(), 2);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(),
static_cast<int32_t>(std::ceil(double(total_num) / kMaxThreadsPerBlock)));
static_cast<int32_t>(
std::ceil(double(total_num) / kMaxThreadsPerBlock)));
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded());
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(), kMaxThreadsPerBlock);
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(),
kMaxThreadsPerBlock);
EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded());
} else {
ASSERT_EQ(all_loops.size(), 3);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), kMaxBlocks);
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded());
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(), kMaxThreadsPerBlock);
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(),
kMaxThreadsPerBlock);
EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded());
EXPECT_EQ(all_loops[2].As<ir::For>()->extent.as_int32(),
static_cast<int32_t>(std::ceil(double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock))));
static_cast<int32_t>(std::ceil(
double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock))));
EXPECT_FALSE(all_loops[2].As<ir::For>()->is_binded());
}
......@@ -83,8 +91,10 @@ class TestAutoBind : public TestAutoGenRuleBase {
auto ir_module = BuildIRModule(result->ir_schedule);
auto source_code = GenSourceCode(ir_module);
VLOG(6) << "Optimized source code:\n" << source_code;
auto manual_ir_module = BuildIRModule(MakeIRSchedule(test_program, /* apply_manual_schedule*/ true));
VLOG(6) << "Manual-schedule compiled source code:\n" << GenSourceCode(manual_ir_module);
auto manual_ir_module = BuildIRModule(
MakeIRSchedule(test_program, /* apply_manual_schedule*/ true));
VLOG(6) << "Manual-schedule compiled source code:\n"
<< GenSourceCode(manual_ir_module);
CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(manual_ir_module),
default_input_names,
......@@ -97,16 +107,20 @@ class TestAutoBind : public TestAutoGenRuleBase {
TEST_F(TestAutoBind, AnalyseApplyType) {
Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}}));
ir::IRSchedule ir_schedule = MakeIRSchedule(
tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}}));
SearchState state(ir_schedule, 0, {});
AutoBind auto_bind(target_);
const std::string& applied_block_name = default_output_names.back();
// outer two loops of initial Expr are spatial loops, so it can be applied
EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name),
RuleApplyType::kApplyAndPruneOtherRules);
state->ir_schedule.Fuse(applied_block_name, {0, 1});
state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0], "threadIdx.x");
state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0],
"threadIdx.x");
// after fuse and bind, there is no loops to be binded.
EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name),
RuleApplyType::kCannotApply);
}
TEST_F(TestAutoBind, ApplyOnBlock) {
......
......@@ -27,12 +27,16 @@ namespace auto_schedule {
AutoGenRule::AutoGenRule(const common::Target& target) : target_(&target) {}
int AutoGenRule::NumberApplicable() const {
CHECK_GE(num_applicable_, 0) << "Call " << GetRuleName() << "::NumberApplicable() without initialization.";
CHECK_GE(num_applicable_, 0)
<< "Call " << GetRuleName()
<< "::NumberApplicable() without initialization.";
return num_applicable_;
}
void AutoGenRule::ApplyRandomly() {
CHECK_GT(num_applicable_, 0) << "Call " << GetRuleName() << "::ApplyRandomly() with NumberApplicable() == 0";
CHECK_GT(num_applicable_, 0)
<< "Call " << GetRuleName()
<< "::ApplyRandomly() with NumberApplicable() == 0";
int index = rand() % num_applicable_;
return Apply(index);
}
......
......@@ -29,15 +29,18 @@ enum class RuleApplyType : int {
// This rule cannot be applied to ModuleExpr.
kCannotApply = 0,
// This rule can be applied to ModuleExpr,
// and the original ModuleExpr will be retained for branching with other rules.
// and the original ModuleExpr will be retained for branching with other
// rules.
kApply = 1,
// This rule can be applied, but the original ModuleExpr will be deleted,
// so the branches with other rules applied on the original ModuleExpr will be pruned.
// so the branches with other rules applied on the original ModuleExpr will be
// pruned.
kApplyAndPruneOtherRules = 2,
};
/**
* Base class for rules of auto-generating schedule (like Ansor's sketch generation)
* Base class for rules of auto-generating schedule (like Ansor's sketch
* generation)
*
*/
class AutoGenRule {
......@@ -46,7 +49,8 @@ class AutoGenRule {
~AutoGenRule() = default;
// Initialize the AutoGenRule, it must be called before further actions.
// Returns false if the rule cannot be applied on the mod_expr, true otherwise.
// Returns false if the rule cannot be applied on the mod_expr, true
// otherwise.
virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0;
// CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so
......@@ -65,11 +69,15 @@ class AutoGenRule {
// Returns the name of the rule, used for debug.
virtual std::string GetRuleName() const = 0;
// Analyze the ApplyType of the rule used for a block determined by a specific SearchState and block name
virtual RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const = 0;
// Analyze the ApplyType of the rule used for a block determined by a specific
// SearchState and block name
virtual RuleApplyType AnalyseApplyType(
SearchState state, const std::string& block_name) const = 0;
// Apply the rule to a block determined by a specific SearchState and block name
virtual std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) = 0;
// Apply the rule to a block determined by a specific SearchState and block
// name
virtual std::vector<SearchState> ApplyOnBlock(
SearchState state, const std::string& block_name) = 0;
protected:
// number of ScheduleBlock that can apply this auto gen rule
......
......@@ -34,18 +34,23 @@
namespace cinn {
namespace auto_schedule {
AutoInline::AutoInline(const common::Target& target, const std::unordered_set<std::string>& no_inline_output_names)
AutoInline::AutoInline(
const common::Target& target,
const std::unordered_set<std::string>& no_inline_output_names)
: AutoGenRule(target), no_inline_output_names_(no_inline_output_names) {}
bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlockRealize* sche_block_realize =
sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
ir::Expr compute_body = sche_block->body;
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);
// Check the schedule block to be inlined is not a reduce tensor.
std::set<ir::Expr> find_store =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
std::set<ir::Expr> find_store = ir::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) {
return false;
}
......@@ -57,8 +62,10 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::
}
// LoweredFunc output can be tensor name or tensor buffer name
if (no_inline_output_names_.find(tensor->name) != no_inline_output_names_.end() ||
no_inline_output_names_.find(tensor->buffer->name) != no_inline_output_names_.end()) {
if (no_inline_output_names_.find(tensor->name) !=
no_inline_output_names_.end() ||
no_inline_output_names_.find(tensor->buffer->name) !=
no_inline_output_names_.end()) {
return false;
}
......@@ -70,26 +77,32 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::
// Check this schedule block is the only writer of the tensor.
find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() && (x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name;
return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name;
});
if (find_store.size() != 1UL) {
return false;
}
// Check there is no overlap between the buffers the schedule block reads and writes.
std::set<ir::Expr> find_load = ir::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr; });
// Check there is no overlap between the buffers the schedule block reads and
// writes.
std::set<ir::Expr> find_load =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr;
});
if (!find_load.empty()) {
return false;
}
ir::Expr store = *(find_store.begin());
ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(), store);
ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(),
store);
if (!inliner.BodyPatternAllowInline()) {
return false;
}
ir::LeafBlockRemovalPlan remove_plan(sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt);
ir::LeafBlockRemovalPlan remove_plan(
sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt);
remove_plan(&root);
if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) {
return false;
......@@ -99,16 +112,20 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::
return true;
}
AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
AutoInlineType AutoInline::AnalyzeInlineType(
const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlockRealize* sche_block_realize =
sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
// Inline if the block has only 1 write buffer
if (sche_block->write_buffers.size() != 1) {
return AutoInlineType::kCannotInline;
}
std::unordered_set<ir::IrNodeTy> no_inline_node_types = {ir::IrNodeTy::IfThenElse};
std::unordered_set<ir::IrNodeTy> no_inline_node_types = {
ir::IrNodeTy::IfThenElse};
if (ContainsNodeType(sche_block->body, no_inline_node_types)) {
return AutoInlineType::kCannotInline;
}
......@@ -131,25 +148,32 @@ RuleApplyType AutoInline::Init(ir::IRSchedule* ir_schedule) {
num_applicable_ = 0;
for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(all_block_realizes_[i], ir_schedule_);
ir::ScheduleBlockRealize* sche_block_realize =
all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AnalyzeScheduleBlockReadWriteBuffer(
sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type =
AnalyzeInlineType(all_block_realizes_[i], ir_schedule_);
if (type != AutoInlineType::kCannotInline) {
++num_applicable_;
apply_indices_and_type_.push_back({i, type});
}
}
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
void AutoInline::Apply(int index) {
CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init";
CHECK(num_applicable_ > 0 && apply_indices_and_type_.size() == num_applicable_)
CHECK(num_applicable_ > 0 &&
apply_indices_and_type_.size() == num_applicable_)
<< "AutoInline::Apply pre-condition doesn't meet";
CHECK(index >= 0 && num_applicable_ > index)
<< "Invalid index for AutoInline::Apply, the index needs 0 <= index && index < NumberApplicable(), "
<< "Currently index = " << index << ", NumberApplicable() = " << num_applicable_;
<< "Invalid index for AutoInline::Apply, the index needs 0 <= index && "
"index < NumberApplicable(), "
<< "Currently index = " << index
<< ", NumberApplicable() = " << num_applicable_;
int apply_index = apply_indices_and_type_[index].first;
Apply(ir_schedule_, all_block_realizes_[apply_index]);
......@@ -158,18 +182,23 @@ void AutoInline::Apply(int index) {
std::string AutoInline::GetRuleName() const { return "AutoInline"; }
RuleApplyType AutoInline::AnalyseApplyType(SearchState state, const std::string& block_name) const {
RuleApplyType AutoInline::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>());
AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(block_expr, &state->ir_schedule);
return type == AutoInlineType::kCannotInline ? RuleApplyType::kCannotApply : RuleApplyType::kApplyAndPruneOtherRules;
return type == AutoInlineType::kCannotInline
? RuleApplyType::kCannotApply
: RuleApplyType::kApplyAndPruneOtherRules;
}
std::vector<SearchState> AutoInline::ApplyOnBlock(SearchState state, const std::string& block_name) {
std::vector<SearchState> AutoInline::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy();
Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
Apply(&new_state->ir_schedule, block_expr);
......@@ -181,7 +210,8 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>());
AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(block_expr, ir_schedule);
if (type == AutoInlineType::kInlineIntoConsumer) {
......@@ -202,8 +232,10 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
// we need to re-analyze
all_block_realizes_ = ir_schedule->GetAllBlocks();
for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
ir::ScheduleBlockRealize* sche_block_realize =
all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
sche_block->read_buffers = {};
sche_block->write_buffers = {};
AnalyzeScheduleBlockReadWriteBuffer(sche_block);
......
......@@ -41,7 +41,8 @@ enum class AutoInlineType : int {
class AutoInline : public AutoGenRule {
public:
AutoInline(const common::Target& target, const std::unordered_set<std::string>& no_inline_output_names);
AutoInline(const common::Target& target,
const std::unordered_set<std::string>& no_inline_output_names);
~AutoInline() = default;
RuleApplyType Init(ir::IRSchedule* ir_schedule) override;
......@@ -50,13 +51,17 @@ class AutoInline : public AutoGenRule {
std::string GetRuleName() const override;
AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const;
AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;
bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const;
bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override;
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override;
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private:
void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr);
......
......@@ -63,7 +63,14 @@ TEST(AutoInline, SingleLoopInline) {
poly::StageMap stages = CreateStages({A, B, C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestAutoInline_SingleLoopInline", stages, {A, C}, {}, {}, nullptr, target, true);
lang::LowerVec("TestAutoInline_SingleLoopInline",
stages,
{A, C},
{},
{},
nullptr,
target,
true);
VLOG(6) << "Expr after lowering:";
VLOG(6) << funcs[0]->body;
......@@ -90,7 +97,8 @@ TEST(AutoInline, SingleLoopInline) {
EXPECT_EQ(exprs.size(), 1UL);
// ApplyOnBlock
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "B");
auto test_func = [](ir::IRSchedule* ir_sch) {
......@@ -130,7 +138,8 @@ TEST(AutoInline, SingleLoopInline) {
// Cannot inline above expr again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"),
RuleApplyType::kCannotApply);
}
TEST(AutoInline, AddReluInline) {
......@@ -151,12 +160,17 @@ TEST(AutoInline, AddReluInline) {
auto graph = std::make_shared<Graph>(program, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target);
const auto& dtype_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
EXPECT_EQ(graph->fusion_groups.size(), 1UL);
std::vector<ir::LoweredFunc> funcs = op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
std::vector<ir::LoweredFunc> funcs =
op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
......@@ -186,10 +200,12 @@ TEST(AutoInline, AddReluInline) {
auto_inline.Apply(0);
// ApplyOnBlock
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_1");
// Auto Inline again
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"),
RuleApplyType::kApplyAndPruneOtherRules);
new_states = auto_inline.ApplyOnBlock(new_states[0], "var_3");
auto test_func = [](ir::IRSchedule* ir_sch) {
......@@ -238,7 +254,8 @@ TEST(AutoInline, AddReluInline) {
// Cannot inline above expr again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"),
RuleApplyType::kCannotApply);
}
#ifdef CINN_WITH_CUDA
......@@ -246,14 +263,8 @@ class TestAutoInline : public TestAutoGenRuleBase {};
/* The single chain graph composed of multiple blocks can be inlined into one.
*
* Before AutoInline: The output of the previous block is the input of another block.
* Loop1:
* x1 = Add()
* Loop2:
* x2 = Multiply(x1)
* Loop3:
* x3 = Add(x2)
* Loop4:
* Before AutoInline: The output of the previous block is the input of another
* block. Loop1: x1 = Add() Loop2: x2 = Multiply(x1) Loop3: x3 = Add(x2) Loop4:
* x4 = Relu(x3)
*
* After AutoInline: All loops are inlined into a loop.
......@@ -263,18 +274,22 @@ class TestAutoInline : public TestAutoGenRuleBase {};
TEST_F(TestAutoInline, SingleChain) {
Target target = common::DefaultNVGPUTarget();
Initialize(target);
std::vector<std::string> input_names = {"bias", "conv_output", "bn_scale", "bn_offset"};
std::vector<std::string> output_names = {"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"};
std::vector<std::string> input_names = {
"bias", "conv_output", "bn_scale", "bn_offset"};
std::vector<std::string> output_names = {
"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"};
std::vector<int32_t> conv_output_shape = {1, 512, 56, 56};
int32_t channel = conv_output_shape[1];
std::vector<tests::VariableInfo> inputs_varinfo({{"conv_output", conv_output_shape},
std::vector<tests::VariableInfo> inputs_varinfo(
{{"conv_output", conv_output_shape},
{"bias", {channel, 1, 1}},
{"bn_scale", {channel, 1, 1}},
{"bn_offset", {channel, 1, 1}}});
// Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo));
ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
......@@ -282,20 +297,23 @@ TEST_F(TestAutoInline, SingleChain) {
// Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_3");
std::vector<std::string> inline_block_names({"var_4", "var_5", "var_6", "var", "var_0", "var_1"});
std::vector<std::string> inline_block_names(
{"var_4", "var_5", "var_6", "var", "var_0", "var_1"});
for (const auto& inline_block_name : inline_block_names) {
new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name);
}
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually =
BuildIRModule(MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true));
auto build_module_manually = BuildIRModule(MakeIRSchedule(
tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually);
......@@ -305,7 +323,10 @@ TEST_F(TestAutoInline, SingleChain) {
GenExecutableKernel(build_module_manually),
input_names,
output_names,
{{conv_output_shape[1], 1, 1}, conv_output_shape, conv_output_shape, conv_output_shape},
{{conv_output_shape[1], 1, 1},
conv_output_shape,
conv_output_shape,
conv_output_shape},
{conv_output_shape, {1}, {1}, {1}, {1}, {1}, {1}},
target);
}
......@@ -335,7 +356,8 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) {
// Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo));
ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
......@@ -343,17 +365,19 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) {
// Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_1");
new_states = auto_inline.ApplyOnBlock(state, "var_0");
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually =
BuildIRModule(MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true));
auto build_module_manually = BuildIRModule(MakeIRSchedule(
tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually);
......@@ -387,14 +411,20 @@ TEST_F(TestAutoInline, OnlySpatialOp) {
Target target = common::DefaultNVGPUTarget();
Initialize(target);
std::vector<std::string> input_names = {"x", "y"};
std::vector<std::string> output_names = {
"var_6", "var_4", "constant_idx_last", "constant_idx_first", "var_2", "var_5"};
std::vector<std::string> output_names = {"var_6",
"var_4",
"constant_idx_last",
"constant_idx_first",
"var_2",
"var_5"};
std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}, {"y", input_shape}});
std::vector<tests::VariableInfo> inputs_varinfo(
{{"x", input_shape}, {"y", input_shape}});
// Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo));
ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
......@@ -402,20 +432,23 @@ TEST_F(TestAutoInline, OnlySpatialOp) {
// Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "constant_idx_first");
std::vector<std::string> inline_block_names({"constant_idx_last", "var_2", "var_5", "var_4"});
std::vector<std::string> inline_block_names(
{"constant_idx_last", "var_2", "var_5", "var_4"});
for (const auto& inline_block_name : inline_block_names) {
new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name);
}
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually =
BuildIRModule(MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true));
auto build_module_manually = BuildIRModule(MakeIRSchedule(
tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually);
......@@ -451,7 +484,8 @@ TEST_F(TestAutoInline, NoReadBufferOp) {
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}});
// Construct the computation graph and convert it to ir::Expr
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo));
ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
......@@ -459,16 +493,18 @@ TEST_F(TestAutoInline, NoReadBufferOp) {
// Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant");
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually =
BuildIRModule(MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true));
auto build_module_manually = BuildIRModule(MakeIRSchedule(
tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually);
......
......@@ -33,11 +33,13 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
auto has_reduce_iter = [](const Expr* x) {
auto* block_realize = x->As<ir::ScheduleBlockRealize>();
if (block_realize) {
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();
auto* schedule_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
for (auto&& var : schedule_block->iter_vars) {
if (var->is_reduce_axis) {
VLOG(6) << "find ScheduleBlockRealize:" << *x << " has reduce_axis:" << var;
VLOG(6) << "find ScheduleBlockRealize:" << *x
<< " has reduce_axis:" << var;
return true;
}
}
......@@ -46,7 +48,8 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
};
// whether has any for-loop with non-serial type
auto has_nonserial_loop = [](const Expr* x) {
if (x->As<ir::For>() && x->As<ir::For>()->for_type() != ir::ForType::Serial) {
if (x->As<ir::For>() &&
x->As<ir::For>()->for_type() != ir::ForType::Serial) {
VLOG(6) << "find non-serial loop:" << *x;
return true;
}
......@@ -55,7 +58,9 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
auto find_target_exprs = ir::CollectIRNodesWithoutTensor(
schedule_block->body,
[&has_reduce_iter, &has_nonserial_loop](const Expr* x) { return has_reduce_iter(x) || has_nonserial_loop(x); });
[&has_reduce_iter, &has_nonserial_loop](const Expr* x) {
return has_reduce_iter(x) || has_nonserial_loop(x);
});
return !find_target_exprs.empty();
}
......@@ -74,44 +79,55 @@ RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) {
Expr root_block = ir_schedule_->GetRootBlock(block_realizes[i]);
auto* block_realize = root_block.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block;
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize);
auto* schedule_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:"
<< Expr(block_realize);
if (MeetCondition(schedule_block)) {
deduplicate_results.emplace(root_block);
}
}
applicable_schedule_blocks_ = {deduplicate_results.begin(), deduplicate_results.end()};
applicable_schedule_blocks_ = {deduplicate_results.begin(),
deduplicate_results.end()};
num_applicable_ = applicable_schedule_blocks_.size();
VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
void AutoUnroll::Apply(int index) {
CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index;
CHECK_LT(index, applicable_schedule_blocks_.size())
<< "invalid apply index:" << index;
auto applied_block = applicable_schedule_blocks_.at(index);
int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()];
ir_schedule_->Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step);
ir_schedule_->Annotate(
applied_block, ir::attr::auto_unroll_max_step, max_step);
return;
}
RuleApplyType AutoUnroll::AnalyseApplyType(SearchState state, const std::string& block_name) const {
RuleApplyType AutoUnroll::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name);
Expr root_block = state->ir_schedule.GetRootBlock(block_expr);
auto* block_realize = root_block.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block;
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize);
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:"
<< Expr(block_realize);
return MeetCondition(schedule_block) ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
return MeetCondition(schedule_block) ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
std::vector<SearchState> AutoUnroll::ApplyOnBlock(SearchState state, const std::string& block_name) {
std::vector<SearchState> AutoUnroll::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy();
Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr);
int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()];
new_state->ir_schedule.Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step);
new_state->ir_schedule.Annotate(
applied_block, ir::attr::auto_unroll_max_step, max_step);
return {new_state};
}
......
......@@ -24,10 +24,11 @@
namespace cinn {
namespace auto_schedule {
// This rule can be applied in a ScheduleBlock has reduce axis or has loops with non-serial type.
// As a result, it will set a attribute with key named ir::attr::auto_unroll_max_step and value
// indicating max permitted unrolled step in the applied ScheduleBlock. Finally, UnrollLoop pass
// will do unroll based on actual situation.
// This rule can be applied in a ScheduleBlock has reduce axis or has loops with
// non-serial type. As a result, it will set a attribute with key named
// ir::attr::auto_unroll_max_step and value indicating max permitted unrolled
// step in the applied ScheduleBlock. Finally, UnrollLoop pass will do unroll
// based on actual situation.
class AutoUnroll : public AutoGenRule {
public:
AutoUnroll(const common::Target& target) : AutoGenRule(target) {}
......@@ -39,9 +40,11 @@ class AutoUnroll : public AutoGenRule {
std::string GetRuleName() const override { return "AutoUnroll"; }
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override;
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override;
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private:
bool MeetCondition(const ir::ScheduleBlock* schedule_block) const;
......
......@@ -39,7 +39,8 @@ TEST(AutoUnroll, Init) {
Target target = common::DefaultHostTarget();
#endif
auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec("test_init", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto funcs = cinn::lang::LowerVec(
"test_init", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto ast_expr = funcs[0]->body;
ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr}));
......@@ -58,7 +59,9 @@ TEST(AutoUnroll, UnrollableApply) {
Placeholder<float> B("B", {K, N});
Var k(K.as_int32(), "k0");
Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
{M, N},
[&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
......@@ -66,11 +69,14 @@ TEST(AutoUnroll, UnrollableApply) {
Target target = common::DefaultHostTarget();
#endif
auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec("test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto funcs = cinn::lang::LowerVec(
"test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto ast_expr = funcs[0]->body;
auto* init_block_realize = ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
auto* init_schedule_block = init_block_realize->schedule_block.As<ir::ScheduleBlock>();
auto* init_block_realize =
ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
auto* init_schedule_block =
init_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_NE(init_schedule_block, nullptr);
ASSERT_TRUE(init_schedule_block->attrs.empty());
VLOG(6) << "Before auto-unroll:\n" << ast_expr;
......@@ -78,25 +84,34 @@ TEST(AutoUnroll, UnrollableApply) {
AutoUnroll test_rule(target);
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {});
ASSERT_EQ(test_rule.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules);
ASSERT_EQ(test_rule.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(test_rule.NumberApplicable(), 1);
test_rule.ApplyRandomly();
// ApplyOnBlock
EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules);
std::vector<cinn::auto_schedule::SearchState> states = test_rule.ApplyOnBlock(state, "C");
EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules);
std::vector<cinn::auto_schedule::SearchState> states =
test_rule.ApplyOnBlock(state, "C");
auto test_func = [](IRSchedule* ir_sch) {
Expr applied_expr = ir_sch->GetModule().GetExprs().front();
auto* applied_block_realize = applied_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
auto* applied_schedule_block = applied_block_realize->schedule_block.As<ir::ScheduleBlock>();
auto* applied_block_realize = applied_expr.As<ir::Block>()
->stmts.front()
.As<ir::ScheduleBlockRealize>();
auto* applied_schedule_block =
applied_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_FALSE(applied_schedule_block->attrs.empty());
EXPECT_EQ(applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1);
const auto& attr_value = applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step);
EXPECT_EQ(
applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1);
const auto& attr_value =
applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step);
const int* max_step = absl::get_if<int>(&attr_value);
EXPECT_NE(max_step, nullptr);
EXPECT_LE(*max_step, 128);
VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front();
VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n"
<< ir_sch->GetModule().GetExprs().front();
};
test_func(&ir_schedule);
......
......@@ -34,7 +34,8 @@ class TestMixRules : public TestAutoGenRuleBase {
};
TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) {
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}});
frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}});
Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op);
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
......@@ -42,7 +43,8 @@ TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) {
VLOG(6) << "Original Expr:\n" << func_bodys[0];
// Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch));
MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
multi_level_tiling.Init(&ir_schedule);
ASSERT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly();
......@@ -54,7 +56,8 @@ TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) {
VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision
CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))),
GenExecutableKernel(BuildIRModule(
MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))),
default_input_names,
default_output_names,
{{32, 32}, {32, 32}},
......
......@@ -38,7 +38,8 @@
namespace cinn {
namespace auto_schedule {
MultiLevelTiling::MultiLevelTiling(const common::Target& target, const Config& config)
MultiLevelTiling::MultiLevelTiling(const common::Target& target,
const Config& config)
: AutoGenRule(target), config_(config) {
for (int i = 0; i < config_.tile_struct.size(); ++i) {
if (config_.tile_struct[i] == 'S') {
......@@ -51,7 +52,8 @@ MultiLevelTiling::MultiLevelTiling(const common::Target& target, const Config& c
}
}
bool MultiLevelTiling::MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const {
bool MultiLevelTiling::MeetCondition(
const ir::ScheduleBlockRealize& sche_block_realize) const {
return NeedsMultiLevelTiling(sche_block_realize);
}
......@@ -61,15 +63,18 @@ RuleApplyType MultiLevelTiling::Init(ir::IRSchedule* ir_schedule) {
applicable_indices_.clear();
num_applicable_ = 0;
for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
ir::ScheduleBlockRealize* sche_block_realize =
all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AnalyzeScheduleBlockReadWriteBuffer(
sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
if (MeetCondition(*sche_block_realize)) {
++num_applicable_;
applicable_indices_.push_back(i);
}
}
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
void MultiLevelTiling::Apply(int index) {
......@@ -77,12 +82,16 @@ void MultiLevelTiling::Apply(int index) {
CHECK(num_applicable_ > 0 && applicable_indices_.size() == num_applicable_)
<< "MultiLevelTiling::Apply pre-condition doesn't meet";
CHECK(index >= 0 && num_applicable_ > index)
<< "Invalid index for MultiLevelTiling::Apply, the index needs 0 <= index && index < NumberApplicable(), "
<< "Currently index = " << index << ", NumberApplicable() = " << num_applicable_;
<< "Invalid index for MultiLevelTiling::Apply, the index needs 0 <= "
"index && index < NumberApplicable(), "
<< "Currently index = " << index
<< ", NumberApplicable() = " << num_applicable_;
int apply_index = applicable_indices_[index];
std::string block_name =
all_block_realizes_[apply_index].As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
std::string block_name = all_block_realizes_[apply_index]
.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
Expr block_expr = all_block_realizes_[apply_index];
ApplyTiling(ir_schedule_, block_expr);
block_expr = ir_schedule_->GetBlock(block_name);
......@@ -96,16 +105,21 @@ void MultiLevelTiling::Apply(int index) {
std::string MultiLevelTiling::GetRuleName() const { return "MultiLevelTiling"; }
RuleApplyType MultiLevelTiling::AnalyseApplyType(SearchState state, const std::string& block_name) const {
RuleApplyType MultiLevelTiling::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>());
AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
return NeedsMultiLevelTiling(*block_realize) ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
return NeedsMultiLevelTiling(*block_realize)
? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
}
std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(SearchState state, const std::string& block_name) {
std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy();
ir::IRSchedule* ir_sch = &new_state->ir_schedule;
Expr block_expr = ir_sch->GetBlock(block_name);
......@@ -119,14 +133,18 @@ std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(SearchState state, const
return {new_state};
}
void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
ir::ScheduleBlockRealize* sche_block_realize = block_expr.As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr) {
ir::ScheduleBlockRealize* sche_block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
tile_loops_.clear();
tile_loops_.resize(config_.tile_struct.size());
std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr);
VLOG(5) << "The number of loops to split in MultiLevelTiling is " << for_exprs.size();
VLOG(5) << "The number of loops to split in MultiLevelTiling is "
<< for_exprs.size();
for (int i = for_exprs.size() - 1; i >= 0; --i) {
ir::For* ir_for = for_exprs[i].As<ir::For>();
VLOG(6) << "Applying Split for MultiLevelTiling on: " << Expr(ir_for);
......@@ -141,8 +159,10 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
int num_split = idx->size();
if (num_split > 1) {
std::vector<Expr> tile_split_factor = ir_schedule->SamplePerfectTile(Expr(ir_for), num_split, 64);
std::vector<Expr> splited = ir_schedule->Split(Expr(ir_for), tile_split_factor);
std::vector<Expr> tile_split_factor =
ir_schedule->SamplePerfectTile(Expr(ir_for), num_split, 64);
std::vector<Expr> splited =
ir_schedule->Split(Expr(ir_for), tile_split_factor);
VLOG(6) << "Finish Split for MultiLevelTiling on above loop";
for (int j = 0; j < num_split; ++j) {
tile_loops_[idx->at(j)].push_back(splited[j]);
......@@ -159,7 +179,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
for (int i = 0; i < for_exprs.size(); ++i) {
loop_var_name_to_idx[for_exprs[i].As<ir::For>()->loop_var->name] = i;
}
CHECK(loop_var_name_to_idx.size() == for_exprs.size()) << "Loops contain duplicate loop var names after split";
CHECK(loop_var_name_to_idx.size() == for_exprs.size())
<< "Loops contain duplicate loop var names after split";
std::vector<Expr> splited_loops;
for (auto& t : tile_loops_) {
......@@ -173,7 +194,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
}
Expr reordered_expr = ir_schedule->Reorder(splited_loops);
VLOG(5) << "Finish Reorder in MultiLevelTiling, now do Fuse and Binding on the main loop chain";
VLOG(5) << "Finish Reorder in MultiLevelTiling, now do Fuse and Binding on "
"the main loop chain";
int num_binds = std::min(config_.bind_axis.size(), tile_loops_.size());
for (int i = 0; i < num_binds; ++i) {
......@@ -182,7 +204,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
for (int j = 0; j < for_exprs.size(); ++j) {
loop_var_name_to_idx[for_exprs[j].As<ir::For>()->loop_var->name] = j;
}
CHECK(loop_var_name_to_idx.size() == for_exprs.size()) << "Loops contain duplicate loop var names before Fusion";
CHECK(loop_var_name_to_idx.size() == for_exprs.size())
<< "Loops contain duplicate loop var names before Fusion";
// Some loops extent may exceed the limited max factor (For example,
// exceed the limit number of CUDA threads), here we check whether
......@@ -209,7 +232,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
Expr fused = ir_schedule->Fuse(tile_loops_[i]);
ir_schedule->Bind(fused, config_.bind_axis[i]);
} else if (first_idx_less_than_max_factor != -1) {
ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], config_.bind_axis[i]);
ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor],
config_.bind_axis[i]);
}
}
......@@ -229,13 +253,17 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
}
}
if (!other_loop_chain_schedule.defined()) {
LOG(WARNING) << "Has non-main loop chain, but not corresponding ScheduleBlock in MultiLevelTiling";
LOG(WARNING) << "Has non-main loop chain, but not corresponding "
"ScheduleBlock in MultiLevelTiling";
continue;
}
std::string other_loop_schedule_name =
other_loop_chain_schedule.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
VLOG(6) << "Found other_loop_schedule_name = " << other_loop_schedule_name;
other_loop_chain_schedule.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
VLOG(6) << "Found other_loop_schedule_name = "
<< other_loop_schedule_name;
int fuse_index = 0;
for (int i = 0; i < num_binds; ++i) {
for_exprs = ir_schedule->GetLoops(other_loop_schedule_name);
......@@ -250,20 +278,23 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
int extent_prod = 1;
int first_idx_less_than_max_factor = -1;
for (int j = 0; j < tile_loops_[i].size(); ++j) {
int extent = for_exprs[fuse_index + j].As<ir::For>()->extent.as_int32();
int extent =
for_exprs[fuse_index + j].As<ir::For>()->extent.as_int32();
extent_prod *= extent;
if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) {
first_idx_less_than_max_factor = fuse_index + j;
}
}
if (extent_prod <= max_factor_) {
std::vector<Expr> loops_to_fuse(for_exprs.begin() + fuse_index,
std::vector<Expr> loops_to_fuse(
for_exprs.begin() + fuse_index,
for_exprs.begin() + fuse_index + tile_loops_[i].size());
Expr fused = ir_schedule->Fuse(loops_to_fuse);
ir_schedule->Bind(fused, config_.bind_axis[i]);
fuse_index += 1;
} else if (first_idx_less_than_max_factor != -1) {
ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], config_.bind_axis[i]);
ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor],
config_.bind_axis[i]);
fuse_index += tile_loops_[i].size();
}
}
......@@ -272,9 +303,12 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
}
}
void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
ir::ScheduleBlockRealize* sch_block_realize = block_expr.As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sch_block = sch_block_realize->schedule_block.As<ir::ScheduleBlock>();
void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr) {
ir::ScheduleBlockRealize* sch_block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sch_block =
sch_block_realize->schedule_block.As<ir::ScheduleBlock>();
std::string block_name = sch_block->name;
// Analyze which buffers can be cached
......@@ -302,85 +336,110 @@ void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& blo
}
// 2.Do CacheRead and get the cache block
ir::Expr cache_block = ir_schedule->CacheRead(block_expr, read_buffer_index, config_.read_cache_memory_type);
ir::Expr cache_block = ir_schedule->CacheRead(
block_expr, read_buffer_index, config_.read_cache_memory_type);
std::string cache_block_name =
cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
std::string target_for_loop_name = loops.back().As<ir::For>()->loop_var->name;
std::string target_for_loop_name =
loops.back().As<ir::For>()->loop_var->name;
// 3.Place the cache_block under target_for_loop
// The original block expr is invalid after the CacheRead schedule,
// so we reacquire the block expr after the schedule according to the block name
// so we reacquire the block expr after the schedule according to the
// block name
block_expr = ir_schedule->GetBlock(block_name);
std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr);
for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
std::string::npos) {
ir_schedule->ComputeAt(cache_block, for_expr, true);
break;
}
}
// 4.Threads under the same block cooperative fetch data from global memory.
// 4.Threads under the same block cooperative fetch data from global
// memory.
Expr new_cache_block = ir_schedule->GetBlock(cache_block_name);
auto cache_block_loops = ir_schedule->GetLoops(new_cache_block);
std::vector<std::string> compute_at_extra_var = utils::Split(
absl::get<std::string>(
new_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->attrs.at(
"compute_at_extra_var")),
absl::get<std::string>(new_cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->attrs.at("compute_at_extra_var")),
",");
std::vector<Expr> buffer_loops;
// int nthreads = 1;
for (const Expr& for_expr : cache_block_loops) {
if (std::find(compute_at_extra_var.begin(),
compute_at_extra_var.end(),
for_expr.As<ir::For>()->loop_var->name) != compute_at_extra_var.end()) {
for_expr.As<ir::For>()->loop_var->name) !=
compute_at_extra_var.end()) {
buffer_loops.push_back(for_expr);
}
}
auto fused_buffer_loop = ir_schedule->Fuse(buffer_loops);
// TODO(BiynXu): Implement vectorize fetching data and pass in vector length
ir_schedule->Annotate(ir_schedule->GetBlock(cache_block_name), ir::attr::cooperative_process, 0);
// TODO(BiynXu): Implement vectorize fetching data and pass in vector
// length
ir_schedule->Annotate(ir_schedule->GetBlock(cache_block_name),
ir::attr::cooperative_process,
0);
}
}
}
void MultiLevelTiling::ApplyCacheWrite(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
ir::Expr cache_block = ir_schedule->CacheWrite(block_expr, 0, config_.write_cache_memory_type);
void MultiLevelTiling::ApplyCacheWrite(ir::IRSchedule* ir_schedule,
ir::Expr& block_expr) {
ir::Expr cache_block =
ir_schedule->CacheWrite(block_expr, 0, config_.write_cache_memory_type);
for (int level : config_.write_cache_levels) {
const auto loops = tile_loops_.at(level - 1);
if (loops.size() == 0) {
continue;
}
std::string target_for_loop_name = loops.back().As<ir::For>()->loop_var->name;
// Because the block name is changed in CacheWrite, we need to calculate the derived name
// according to the logic of CacheWrite and find the loop structure according to the derived name.
std::string target_for_loop_name =
loops.back().As<ir::For>()->loop_var->name;
// Because the block name is changed in CacheWrite, we need to calculate the
// derived name according to the logic of CacheWrite and find the loop
// structure according to the derived name.
const std::string original_block_name =
block_expr.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
const std::string derivative_block_name =
original_block_name + "_" + config_.write_cache_memory_type + "_temp_buffer";
block_expr.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
const std::string derivative_block_name = original_block_name + "_" +
config_.write_cache_memory_type +
"_temp_buffer";
std::vector<Expr> for_exprs = ir_schedule->GetLoops(derivative_block_name);
for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos) {
ir_schedule->ReverseComputeAt(ir_schedule->GetBlock(original_block_name), for_expr, true);
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
std::string::npos) {
ir_schedule->ReverseComputeAt(
ir_schedule->GetBlock(original_block_name), for_expr, true);
}
}
const std::string reduce_init_block_name = original_block_name + "__reduce_init";
const std::string reduce_init_block_name =
original_block_name + "__reduce_init";
for_exprs = ir_schedule->GetLoops(derivative_block_name);
for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos &&
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
std::string::npos &&
ir_schedule->HasBlock(reduce_init_block_name)) {
ir_schedule->SimpleComputeAt(ir_schedule->GetBlock(reduce_init_block_name), for_expr);
ir_schedule->SimpleComputeAt(
ir_schedule->GetBlock(reduce_init_block_name), for_expr);
}
}
}
}
const std::unordered_map<common::Target::Arch, MultiLevelTiling::Config> MultiLevelTiling::kConfigs{
const std::unordered_map<common::Target::Arch, MultiLevelTiling::Config>
MultiLevelTiling::kConfigs{
{common::Target::Arch::NVGPU,
MultiLevelTiling::Config{
/*bind_axis*/ std::vector<std::string>{"blockIdx.x", "threadIdx.x"},
/*bind_axis*/ std::vector<std::string>{"blockIdx.x",
"threadIdx.x"},
/*tile_struct*/ std::string("SSSRRSRS"),
/*read_cache_memory_type*/ std::string("shared"),
/*read_cache_levels*/ std::vector<int>{4},
......
......@@ -72,9 +72,11 @@ class MultiLevelTiling : public AutoGenRule {
// Returns true if sche_block_realize is applicable by MultiLevelTiling
bool MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override;
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override;
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
// Sample pair of integer type (a, b) such as a * b = extent
template <typename T>
......@@ -101,7 +103,8 @@ class MultiLevelTiling : public AutoGenRule {
// Sample num_split integers whose product equals extent
template <typename T>
std::vector<T> SampleTileSplit(T extent, int num_split) const {
CHECK_GT(num_split, 0) << "num_split in SampleTileSplit must be greater than 0";
CHECK_GT(num_split, 0)
<< "num_split in SampleTileSplit must be greater than 0";
if (num_split == 1) {
return {extent};
}
......
......@@ -48,11 +48,13 @@ TEST(MultiLevelTile, SampleSplitTwo) {
Target target = common::DefaultHostTarget();
#endif
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch));
MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
for (int i = 0; i < 100; ++i) {
size_t number_to_split = rand() % 65535 + 2; // random number in [2, 2^16]
std::vector<size_t> split = multi_level_tiling.SampleSplitTwo<size_t>(number_to_split);
std::vector<size_t> split =
multi_level_tiling.SampleSplitTwo<size_t>(number_to_split);
EXPECT_EQ(split.size(), 2UL);
EXPECT_EQ(split[0] * split[1], number_to_split);
}
......@@ -67,12 +69,14 @@ TEST(MultiLevelTile, SampleTileSplit) {
Target target = common::DefaultHostTarget();
#endif
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch));
MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
for (int i = 0; i < 100; ++i) {
int number_to_split = rand() % 65535 + 2; // random number in [2, 2^16]
int split_size = rand() % 5 + 1; // random in [1, 5]
std::vector<int> split = multi_level_tiling.SampleTileSplit<int>(number_to_split, split_size);
std::vector<int> split =
multi_level_tiling.SampleTileSplit<int>(number_to_split, split_size);
EXPECT_EQ(split.size(), static_cast<size_t>(split_size));
int product = 1;
for (int num : split) {
......@@ -102,21 +106,31 @@ TEST(MultiLevelTile, SimpleLoops) {
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_SimpleLoops", stages, {C}, {}, {}, nullptr, target, true);
lang::LowerVec("TestMultiLevelTile_SimpleLoops",
stages,
{C},
{},
{},
nullptr,
target,
true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch));
MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly();
// ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) {
......@@ -152,26 +166,30 @@ TEST(MulitLevelTile, MatrixMultiply) {
Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {},
nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch));
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules);
MultiLevelTiling multi_level_tiling(target,
MultiLevelTiling::kConfigs.at(target.arch)); ir::IRSchedule
ir_schedule(ir::ModuleExpr({ast_expr})); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly();
// ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, "C");
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules); auto new_states =
multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
......@@ -201,16 +219,19 @@ TEST_F(TestMultiLevelTiling, Matmul) {
std::vector<int32_t> out_shape = {32, 32};
Initialize(common::DefaultNVGPUTarget());
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);
SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch));
MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
auto new_states =
multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString();
std::string ir = GetIR(new_states[0]->ir_schedule);
std::string expected_ir = R"ROC(Expr 0 {
......@@ -332,7 +353,8 @@ TEST_F(TestMultiLevelTiling, Matmul) {
// execute and check precision
CheckResult(
GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(
matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
{X_shape, Y_shape},
......@@ -349,14 +371,17 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {
Initialize(common::DefaultNVGPUTarget());
frontend::Program reduce_sum_op =
tests::OpBuilder("reduce_sum").Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}});
tests::OpBuilder("reduce_sum")
.Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}});
ir::IRSchedule ir_schedule = MakeIRSchedule(reduce_sum_op);
SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch));
// EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), RuleApplyType::kCannotApply);
MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
// EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state,
// default_output_names[0]), RuleApplyType::kCannotApply);
}
TEST_F(TestMultiLevelTiling, Pool2d) {
......@@ -374,7 +399,8 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
std::string data_format = "NCHW";
bool adaptive = false;
std::string padding_algorithm = "EXPLICIT";
frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shape}},
frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
{{"input", input_shape}},
{{"pool_type", pooling_type},
{"kernel_size", ksize},
{"stride_size", strides},
......@@ -403,7 +429,8 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
MultiLevelTiling multi_level_tiling(target_, mlt_config);
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
auto new_states =
multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString();
std::string ir = GetIR(new_states[0]->ir_schedule);
......@@ -534,9 +561,10 @@ Expr 1 {
VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision
CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(
BuildIRModule(MakeIRSchedule(pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
CheckResult(
GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(
pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
{input_shape},
......
......@@ -34,11 +34,15 @@ class SkipRule : public AutoGenRule {
std::string GetRuleName() const override;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override {
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override {
return RuleApplyType::kApply;
}
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override { return {state}; }
std::vector<SearchState> ApplyOnBlock(
SearchState state, const std::string& block_name) override {
return {state};
}
};
} // namespace auto_schedule
......
......@@ -53,7 +53,8 @@ TEST(SkipRule, Basic) {
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: ";
......@@ -69,7 +70,8 @@ TEST(SkipRule, Basic) {
// ApplyOnBlock
EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply);
std::vector<cinn::auto_schedule::SearchState> states = skip_rule.ApplyOnBlock(state, "C");
std::vector<cinn::auto_schedule::SearchState> states =
skip_rule.ApplyOnBlock(state, "C");
auto test_func = [&ast_expr](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
......@@ -100,7 +102,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: ";
......@@ -111,7 +114,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
SearchState state(ir_schedule, 0, {});
EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply);
std::vector<cinn::auto_schedule::SearchState> states = skip_rule.ApplyOnBlock(state, "C");
std::vector<cinn::auto_schedule::SearchState> states =
skip_rule.ApplyOnBlock(state, "C");
std::vector<ir::Expr> exprs = states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
......
......@@ -46,22 +46,28 @@ void TestAutoGenRuleBase::Initialize(const common::Target& target) {
backend_compier_ = backends::Compiler::Create(target);
}
ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(const frontend::Program& test_program,
ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
const frontend::Program& test_program,
utils::LinearRandomEngine::StateType rand_seed,
bool apply_manual_schedule) {
Context::Global().ResetNameId();
auto graph = std::make_shared<hlir::framework::Graph>(test_program, target_);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
LOG_IF(WARNING, graph->fusion_groups.size() > 1) << "Test Graph has more than 1 group";
auto& dtype_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
auto& shape_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
LOG_IF(WARNING, graph->fusion_groups.size() > 1)
<< "Test Graph has more than 1 group";
auto& dtype_dict =
graph->GetMutableAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
auto& shape_dict = graph->GetMutableAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
if (apply_manual_schedule) {
lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
} else {
lowered_funcs_ = op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
lowered_funcs_ =
op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
}
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
......@@ -76,14 +82,16 @@ std::string TestAutoGenRuleBase::GetIR(const ir::IRSchedule& schedule) {
const auto& exprs = schedule.GetModule().GetExprs();
std::stringstream module_stream;
for (auto i = 0; i < exprs.size(); ++i) {
module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr " << i << "\n";
module_stream << "Expr " << i << " {\n"
<< exprs.at(i) << "\n} // end Expr " << i << "\n";
}
return module_stream.str();
}
ir::Module TestAutoGenRuleBase::BuildIRModule(const ir::IRSchedule& schedule) {
auto&& updated_bodys = schedule.GetModule().GetExprs();
CHECK_EQ(lowered_funcs_.size(), updated_bodys.size()) << "associated exprs size not equal";
CHECK_EQ(lowered_funcs_.size(), updated_bodys.size())
<< "associated exprs size not equal";
ir::Module::Builder builder("test_bulder", this->target_);
for (int i = 0; i < lowered_funcs_.size(); ++i) {
......@@ -102,20 +110,24 @@ std::string TestAutoGenRuleBase::GenSourceCode(const ir::Module& ir_module) {
if (target_ == common::DefaultNVGPUTarget()) {
codegen = std::make_unique<backends::CodeGenCUDA_Dev>(this->target_);
} else {
codegen = std::make_unique<backends::CodeGenCX86>(this->target_, CodeGenCX86::Feature::AVX512);
codegen = std::make_unique<backends::CodeGenCX86>(
this->target_, CodeGenCX86::Feature::AVX512);
}
#else
codegen = std::make_unique<backends::CodeGenCX86>(this->target_, CodeGenCX86::Feature::AVX512);
codegen = std::make_unique<backends::CodeGenCX86>(
this->target_, CodeGenCX86::Feature::AVX512);
#endif
codegen->SetInlineBuiltinCodes(false);
return codegen->Compile(ir_module, CodeGenC::OutputKind::CImpl);
}
raw_func_type TestAutoGenRuleBase::GenExecutableKernel(const ir::Module& ir_module) {
raw_func_type TestAutoGenRuleBase::GenExecutableKernel(
const ir::Module& ir_module) {
auto&& func_name = lowered_funcs_.front()->name;
// Compile to machine code
backend_compier_->Build(ir_module);
auto test_func_ptr = reinterpret_cast<void (*)(void**, int32_t)>(backend_compier_->Lookup(func_name));
auto test_func_ptr = reinterpret_cast<void (*)(void**, int32_t)>(
backend_compier_->Lookup(func_name));
return test_func_ptr;
}
......@@ -138,15 +150,19 @@ void MemoryCopy(const float* src, float* dst, int numel, std::string type) {
}
}
void AddDataToScope(
Scope* scope, const common::Target& target, float* data_ptr, std::string name, const std::vector<int>& shape) {
void AddDataToScope(Scope* scope,
const common::Target& target,
float* data_ptr,
std::string name,
const std::vector<int>& shape) {
auto* var = scope->Var<Tensor>(name);
auto& tensor = absl::get<Tensor>(*var);
CHECK(shape.size()) << "The size of shape can not be 0.";
Shape cinn_shape(shape);
tensor->Resize(cinn_shape);
auto* tgt_data_ptr = tensor->mutable_data<float>(target);
std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost";
std::string mem_cpy_type =
target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost";
MemoryCopy(data_ptr, tgt_data_ptr, cinn_shape.numel(), mem_cpy_type);
}
......@@ -159,16 +175,20 @@ void CheckResult(raw_func_type test_func,
const common::Target& target) {
CHECK(input_names.size()) << "The number of inputs must be greater than 0.";
CHECK(output_names.size()) << "The number of outputs must be greater than 0.";
CHECK_EQ(input_names.size(), input_shapes.size()) << "The quantity of input_names and input_shapes must be equal.";
CHECK_EQ(input_names.size(), input_shapes.size())
<< "The quantity of input_names and input_shapes must be equal.";
CHECK_EQ(output_names.size(), output_shapes.size())
<< "The quantity of output_names and output_shapes must be equal.";
// Initialize data
std::vector<float*> input_data_ptrs(input_names.size());
for (int i = 0; i < input_shapes.size(); ++i) {
int input_data_numel =
std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) { return a * b; });
input_data_ptrs[i] = reinterpret_cast<float*>(malloc(input_data_numel * sizeof(float)));
int input_data_numel = std::accumulate(
input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) {
return a * b;
});
input_data_ptrs[i] =
reinterpret_cast<float*>(malloc(input_data_numel * sizeof(float)));
for (int j = 0; j < input_data_numel; ++j) {
input_data_ptrs[i][j] = (rand() * 1.f) / RAND_MAX;
}
......@@ -177,24 +197,35 @@ void CheckResult(raw_func_type test_func,
std::vector<float*> expected_output_data_ptrs(output_names.size());
std::vector<int> output_data_numels(output_shapes.size());
for (int i = 0; i < output_shapes.size(); ++i) {
output_data_numels[i] =
std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) { return a * b; });
test_output_data_ptrs[i] = reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
output_data_numels[i] = std::accumulate(
output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) {
return a * b;
});
test_output_data_ptrs[i] =
reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(test_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
expected_output_data_ptrs[i] = reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
expected_output_data_ptrs[i] =
reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(
expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
}
auto launch_kernel_fn = [&](raw_func_type& raw_func, std::vector<float*>& output_data_ptrs) {
auto launch_kernel_fn = [&](raw_func_type& raw_func,
std::vector<float*>& output_data_ptrs) {
// Initialize scope
Scope scope;
// Initialize input data in scope.
for (int i = 0; i < input_names.size(); ++i) {
AddDataToScope(&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]);
AddDataToScope(
&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]);
}
// Initialize output data in scope.
for (int i = 0; i < output_names.size(); ++i) {
AddDataToScope(&scope, target, output_data_ptrs[i], output_names[i], output_shapes[i]);
AddDataToScope(&scope,
target,
output_data_ptrs[i],
output_names[i],
output_shapes[i]);
}
// Create Instruction and run
......@@ -208,8 +239,11 @@ void CheckResult(raw_func_type test_func,
// data
for (int i = 0; i < output_names.size(); ++i) {
const float* result_ptr = scope.GetTensor(output_names[i])->data<float>();
std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost";
MemoryCopy(result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type);
std::string mem_cpy_type = target == common::DefaultNVGPUTarget()
? "DeviceToHost"
: "HostToHost";
MemoryCopy(
result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type);
}
};
......@@ -220,7 +254,8 @@ void CheckResult(raw_func_type test_func,
// Check result
for (int i = 0; i < output_shapes.size(); ++i) {
for (int j = 0; j < output_data_numels[i]; ++j) {
ASSERT_NEAR(test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4);
ASSERT_NEAR(
test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4);
}
}
......
......@@ -47,15 +47,18 @@ class TestAutoGenRuleBase : public ::testing::Test {
// Initialize context for specified target
void Initialize(const common::Target& target);
// construct an ir::IRSchedule by lowering the specified for following AutoGenRule test
ir::IRSchedule MakeIRSchedule(const frontend::Program& test_program,
// construct an ir::IRSchedule by lowering the specified for following
// AutoGenRule test
ir::IRSchedule MakeIRSchedule(
const frontend::Program& test_program,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool apply_manual_schedule = false);
// Get the IR of bodies in IRSchedule
std::string GetIR(const ir::IRSchedule& schedule);
// build ir::Module from the original lowered funcs with their bodies updated by the schedule
// build ir::Module from the original lowered funcs with their bodies updated
// by the schedule
ir::Module BuildIRModule(const ir::IRSchedule& schedule);
// generate source code with the built ir module
......@@ -75,9 +78,12 @@ class TestAutoGenRuleBase : public ::testing::Test {
* @params-2: Expected function pointer for comparison.
* @params-3: Names of input data.
* @params-4: Names of output data.
* @params-5: Shapes of the input data, each input corresponds to a std::vector<int>.
* @params-6: Shapes of the output data, each output corresponds to a std::vector<int>.
* @params-7: The Target expressing computing platform and architecture of the function to be tested.
* @params-5: Shapes of the input data, each input corresponds to a
* std::vector<int>.
* @params-6: Shapes of the output data, each output corresponds to a
* std::vector<int>.
* @params-7: The Target expressing computing platform and architecture of the
* function to be tested.
* @return: void
*/
void CheckResult(raw_func_type test_func,
......
......@@ -26,20 +26,26 @@ namespace auto_schedule {
class SearchState;
// Select the next block to be operated for SearchState during the search process
// Select the next block to be operated for SearchState during the search
// process
class BlockSampler {
public:
/**
* @brief Create a BlockSampler with the specific strategy name and necessary construct parameters.
* @brief Create a BlockSampler with the specific strategy name and necessary
* construct parameters.
* @param all_blocks All possible blocks to be sampled.
* @param default_remove_policy The default option to determine whether to delete the next block after selecting it.
* @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The block sampling strategy.
* Currently, the available strategies are "traversal" and "probabilistic",
* where "traversal" means to select blocks one by one until all blocks are traversed,
* and "probabilistic" means randomly picking blocks according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a weight.
* Currently, the available strategies are "traversal" and
* "probabilistic", where "traversal" means to select blocks one by one until
* all blocks are traversed, and "probabilistic" means randomly picking blocks
* according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/
static std::unique_ptr<BlockSampler> Make(const std::vector<ir::Expr>& all_blocks,
static std::unique_ptr<BlockSampler> Make(
const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy = true,
const std::string& strategy = "traversal",
utils::LinearRandomEngine::StateType rand_seed = 0,
......@@ -56,18 +62,22 @@ class BlockSampler {
protected:
// A BlockSampler object should be created with the static function Make()
BlockSampler(const std::vector<ir::Expr>& all_blocks, bool default_remove_policy);
BlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy);
// Select a block to apply rule
// The param remove is used to determine whether to delete the next block after selecting it,
// If remove == true, it will not be sampled in the future.
// The param remove is used to determine whether to delete the next block
// after selecting it, If remove == true, it will not be sampled in the
// future.
virtual std::string NextBlock(bool remove) = 0;
// The names of all blocks
// Because the Block Expr will be changed in the search process, the name is saved for indexing
// Because the Block Expr will be changed in the search process, the name is
// saved for indexing
std::vector<std::string> all_blocks_;
// The default policy to determine whether to delete the next block after selecting it.
// The default policy to determine whether to delete the next block after
// selecting it.
bool default_remove_policy_;
};
......@@ -75,7 +85,8 @@ class BlockSampler {
// witch means to select blocks one by one until all blocks are traversed.
class TraversalBlockSampler : public BlockSampler {
public:
TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks, bool default_remove_policy)
TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy)
: BlockSampler(all_blocks, default_remove_policy), cur_idx_(0) {}
const char* Name() const override { return "traversal"; }
......
......@@ -24,7 +24,8 @@ namespace auto_schedule {
std::vector<ir::Expr> CreateTestBlocks() {
std::vector<ir::Expr> blocks;
for (int i = 0; i < 3; ++i) {
ir::Expr block = ir::ScheduleBlock::Make({}, {}, {}, "block_" + std::to_string(i), ir::Expr());
ir::Expr block = ir::ScheduleBlock::Make(
{}, {}, {}, "block_" + std::to_string(i), ir::Expr());
blocks.push_back(ir::ScheduleBlockRealize::Make({}, block));
}
return blocks;
......@@ -32,9 +33,11 @@ std::vector<ir::Expr> CreateTestBlocks() {
TEST(BlockSampler, Make) {
std::vector<ir::Expr> mock_blocks = CreateTestBlocks();
auto traversal_block_sampler = BlockSampler::Make(mock_blocks, true, "traversal");
auto traversal_block_sampler =
BlockSampler::Make(mock_blocks, true, "traversal");
ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
auto probabilistic_block_sampler = BlockSampler::Make(mock_blocks, true, "probabilistic");
auto probabilistic_block_sampler =
BlockSampler::Make(mock_blocks, true, "probabilistic");
ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
}
......@@ -55,14 +58,16 @@ TEST(TraversalBlockSampler, NextBlock) {
TEST(ProbabilisticBlockSampler, NextBlock) {
std::vector<ir::Expr> blocks = CreateTestBlocks();
auto probabilistic_block_sampler = BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1});
auto probabilistic_block_sampler =
BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1});
std::string block_name;
for (int i = 0; i < 20; ++i) {
block_name = probabilistic_block_sampler->NextBlock();
VLOG(6) << "next block name: " << block_name;
}
probabilistic_block_sampler = BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1});
probabilistic_block_sampler =
BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1});
probabilistic_block_sampler->NextBlock();
probabilistic_block_sampler->NextBlock();
probabilistic_block_sampler->NextBlock();
......
......@@ -30,16 +30,21 @@ class SearchState;
class RuleSampler {
public:
/**
* @brief Create a RuleSampler with the specific strategy name and necessary construct parameters.
* @brief Create a RuleSampler with the specific strategy name and necessary
* construct parameters.
* @param potential_rules All possible rules to be sampled.
* @param default_remove_policy The default option to determine whether to delete the next block after selecting it.
* @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The rule sampling strategy.
* Currently, the available strategies are "traversal" and "probabilistic",
* where "traversal" means to select rules one by one until all rules are traversed,
* and "probabilistic" means randomly picking rules according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a weight.
* Currently, the available strategies are "traversal" and
* "probabilistic", where "traversal" means to select rules one by one until
* all rules are traversed, and "probabilistic" means randomly picking rules
* according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/
static std::unique_ptr<RuleSampler> Make(const std::vector<AutoGenRule*>& potential_rules,
static std::unique_ptr<RuleSampler> Make(
const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy = true,
const std::string& strategy = "traversal",
utils::LinearRandomEngine::StateType rand_seed = 0,
......@@ -55,18 +60,21 @@ class RuleSampler {
protected:
// A RuleSampler object should be created with the static function Make()
RuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy)
: potential_rules_(&potential_rules), default_remove_policy_(default_remove_policy) {}
RuleSampler(const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy)
: potential_rules_(&potential_rules),
default_remove_policy_(default_remove_policy) {}
// Select a rule to apply.
// The param remove is used to determine whether to delete the next rule after selecting it,
// If remove == true, it will not be sampled in the future.
// The param remove is used to determine whether to delete the next rule after
// selecting it, If remove == true, it will not be sampled in the future.
virtual AutoGenRule* NextRule(bool remove) = 0;
// The pointer refers to all potential rules
const std::vector<AutoGenRule*>* potential_rules_;
// The default policy to determine whether to delete the next rule after selecting it.
// The default policy to determine whether to delete the next rule after
// selecting it.
bool default_remove_policy_;
};
......@@ -74,7 +82,8 @@ class RuleSampler {
// witch means to select rules one by one until all rules are traversed.
class TraversalRuleSampler : public RuleSampler {
public:
TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy)
TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy)
: RuleSampler(potential_rules, default_remove_policy), cur_idx_(0) {}
const char* Name() const override { return "traversal"; }
......
......@@ -28,13 +28,16 @@ Target target = common::DefaultNVGPUTarget();
Target target = common::DefaultHostTarget();
#endif
std::vector<AutoGenRule*> GenerateTestRules() { return {new AutoUnroll(target), new SkipRule(target)}; }
std::vector<AutoGenRule*> GenerateTestRules() {
return {new AutoUnroll(target), new SkipRule(target)};
}
TEST(RuleSampler, Make) {
std::vector<AutoGenRule*> rules = GenerateTestRules();
auto traversal_block_sampler = RuleSampler::Make(rules, true, "traversal");
ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
auto probabilistic_block_sampler = RuleSampler::Make(rules, true, "probabilistic");
auto probabilistic_block_sampler =
RuleSampler::Make(rules, true, "probabilistic");
ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
}
......@@ -58,14 +61,16 @@ TEST(TraversalRuleSampler, NextRule) {
TEST(ProbabilisticRuleSampler, NextRule) {
std::vector<AutoGenRule*> rules = GenerateTestRules();
auto probabilistic_rule_sampler = RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
auto probabilistic_rule_sampler =
RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
AutoGenRule* rule;
for (int i = 0; i < 20; ++i) {
rule = probabilistic_rule_sampler->NextRule();
VLOG(6) << "next rule name: " << rule->GetRuleName();
}
probabilistic_rule_sampler = RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
probabilistic_rule_sampler =
RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
probabilistic_rule_sampler->NextRule();
probabilistic_rule_sampler->NextRule();
ASSERT_EQ(nullptr, probabilistic_rule_sampler->NextRule());
......
......@@ -39,18 +39,23 @@ DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
SearchSpace::SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed)
: tune_task_(tune_task), rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
SearchSpace::SearchSpace(const TuneTask& tune_task,
utils::LinearRandomEngine::StateType rand_seed)
: tune_task_(tune_task),
rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
const auto& target = tune_task_.target;
// initialize a set of rules and they are commonly used by all states
// TODO(zhhsplendid): pass correct output names to AutoInline
// sketch_rules_.emplace_back(new AutoInline(target, tune_task_.output_names));
sketch_rules_.emplace_back(new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch)));
// sketch_rules_.emplace_back(new AutoInline(target,
// tune_task_.output_names));
sketch_rules_.emplace_back(
new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch)));
sketch_rules_.emplace_back(new AutoUnroll(target));
sketch_rules_.emplace_back(new SkipRule(target));
}
SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model) {
SearchState SearchSpace::GetScheduleMutate(const SearchState& state,
const ExprCostModel& cost_model) {
bool has_manual_schedule = false;
if (has_manual_schedule) {
SearchState ret = ManualScheduleMutate(state);
......@@ -58,9 +63,11 @@ SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprC
}
SearchState ret = RandomScheduleMutate(state);
if (FLAGS_auto_schedule_use_cost_model) {
ret->predicted_cost = cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target);
ret->predicted_cost =
cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target);
}
VLOG(4) << JoinStatesDebugString("SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5));
VLOG(4) << JoinStatesDebugString(
"SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5));
return ret;
}
......@@ -79,7 +86,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
for (int idx = 0; idx != ret->applicable_rules.size(); ++idx) {
AutoGenRule* rule = ret->applicable_rules.at(idx);
RuleApplyType apply_type = rule->Init(&ret->ir_schedule);
VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "=" << static_cast<int>(apply_type);
VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "="
<< static_cast<int>(apply_type);
apply_types[idx] = apply_type;
if (apply_type != RuleApplyType::kCannotApply) {
weight_to_rule_index[cur_weight] = idx;
......@@ -94,7 +102,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
}
// 3. Sample a schedule on the distribution
int sample_weighted_index = utils::SampleUniformInt(0, cur_weight, &rand_seed_);
int sample_weighted_index =
utils::SampleUniformInt(0, cur_weight, &rand_seed_);
auto iter = weight_to_rule_index.upper_bound(sample_weighted_index);
--iter;
......@@ -102,13 +111,15 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
int sample_rule_index = iter->second;
CHECK_LT(sample_rule_index, ret->applicable_rules.size());
AutoGenRule* sample_rule = ret->applicable_rules.at(sample_rule_index);
VLOG(7) << "Apply rule: " << sample_rule->GetRuleName() << " with index=" << sample_weighted_index - iter->first;
VLOG(7) << "Apply rule: " << sample_rule->GetRuleName()
<< " with index=" << sample_weighted_index - iter->first;
// 4. Apply the schedule change
sample_rule->Apply(sample_weighted_index - iter->first);
// 5. Remove the rule after applying it
if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) {
ret->applicable_rules.erase(ret->applicable_rules.begin() + sample_rule_index);
ret->applicable_rules.erase(ret->applicable_rules.begin() +
sample_rule_index);
}
return ret;
......@@ -116,17 +127,20 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
VLOG(5) << "SearchSpace::GetRandomInitialSketch with num=" << num;
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_));
std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end(), std::back_inserter(init_rules), [](const auto& rule) {
return rule.get();
});
std::transform(sketch_rules_.begin(),
sketch_rules_.end(),
std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
std::vector<SearchState> result;
while (result.size() < num) {
SearchState state(init_schedule, SearchState::NOT_INIT_COST, init_rules);
for (int i = 0; i < init_sketch_random_depth_; ++i) {
VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: " << i;
VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: "
<< i;
state = RandomScheduleMutate(state);
if (state->applicable_rules.empty()) {
break;
......@@ -134,7 +148,9 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
}
VLOG(5) << JoinStatesDebugString(
"SearchSpace::GetRandomInitialSketch-New_Sketch", {state}, /*verbose=*/VLOG_IS_ON(6));
"SearchSpace::GetRandomInitialSketch-New_Sketch",
{state},
/*verbose=*/VLOG_IS_ON(6));
result.emplace_back(std::move(state));
}
return result;
......@@ -142,15 +158,18 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
VLOG(5) << "SearchSpace::InitSketchWithRandomPrunedStrategy";
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_));
auto all_blocks = init_schedule.GetAllBlocks();
auto block_sampler = BlockSampler::Make(all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_));
auto block_sampler = BlockSampler::Make(
all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_));
std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) {
return rule.get();
});
std::transform(sketch_rules_.begin(),
sketch_rules_.end() - 1,
std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
......@@ -159,7 +178,8 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
std::vector<SearchState>* p_states_next = &states_buf2;
int total_steps = 0, steps;
std::string block_name;
while ("" != (block_name = block_sampler->NextBlock()) && total_steps < init_sketch_random_depth_) {
while ("" != (block_name = block_sampler->NextBlock()) &&
total_steps < init_sketch_random_depth_) {
steps = utils::SampleUniformInt(1, init_rules.size() + 1, &rand_seed_);
if (total_steps + steps > init_sketch_random_depth_) {
steps = init_sketch_random_depth_ - total_steps;
......@@ -167,29 +187,39 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
total_steps += steps;
p_states_next->clear();
for (const auto& state : *p_states_cur) {
auto rule_sampler = RuleSampler::Make(init_rules, true, "probabilistic", utils::ForkRandomState(&rand_seed_));
auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), steps, false, 1);
p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end());
auto rule_sampler =
RuleSampler::Make(init_rules,
true,
"probabilistic",
utils::ForkRandomState(&rand_seed_));
auto new_states = ApplySketchRule(
state, block_name, rule_sampler.get(), steps, false, 1);
p_states_next->insert(
p_states_next->end(), new_states.begin(), new_states.end());
}
std::swap(p_states_cur, p_states_next);
}
VLOG(5) << JoinStatesDebugString(
"SearchSpace::InitSketchWithRandomPrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6));
"SearchSpace::InitSketchWithRandomPrunedStrategy",
*p_states_cur,
/*verbose=*/VLOG_IS_ON(6));
return *p_states_cur;
}
std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() {
VLOG(5) << "SearchSpace::InitSketchWithRulePrunedStrategy";
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_));
auto all_blocks = init_schedule.GetAllBlocks();
std::reverse(all_blocks.begin(), all_blocks.end());
auto block_sampler = BlockSampler::Make(all_blocks, true, "traversal");
std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) {
return rule.get();
});
std::transform(sketch_rules_.begin(),
sketch_rules_.end() - 1,
std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
......@@ -201,17 +231,22 @@ std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() {
p_states_next->clear();
for (const auto& state : *p_states_cur) {
auto rule_sampler = RuleSampler::Make(init_rules, true, "traversal");
auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), 0, true);
p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end());
auto new_states =
ApplySketchRule(state, block_name, rule_sampler.get(), 0, true);
p_states_next->insert(
p_states_next->end(), new_states.begin(), new_states.end());
}
std::swap(p_states_cur, p_states_next);
}
VLOG(5) << JoinStatesDebugString(
"SearchSpace::InitSketchWithRulePrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6));
"SearchSpace::InitSketchWithRulePrunedStrategy",
*p_states_cur,
/*verbose=*/VLOG_IS_ON(6));
return *p_states_cur;
}
std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::string& strategy) {
std::vector<SearchState> SearchSpace::GenerateSketches(
int num, const std::string& strategy) {
VLOG(4) << "SearchSpace::GenerateSketches with num = " << num;
if (strategy == "random") {
......@@ -239,11 +274,13 @@ std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::strin
}
}
}
VLOG(4) << JoinStatesDebugString("SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5));
VLOG(4) << JoinStatesDebugString(
"SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5));
return result;
}
std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
std::vector<SearchState> SearchSpace::ApplySketchRule(
const SearchState& state,
const std::string& block_name,
RuleSampler* rule_sampler,
int steps,
......@@ -252,15 +289,18 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
std::list<SearchState> layer{state};
int step = 0;
AutoGenRule* rule;
// After determining a SearchState and a block, each rule has two possibilities: apply and not apply.
// In all transfer spaces, select a rule at each step, and collect all possible new states arrived by apply and not
// apply. This forms a tree, and we can use rule pruning or random pruning to reduce the number of sketches.
// After determining a SearchState and a block, each rule has two
// possibilities: apply and not apply. In all transfer spaces, select a rule
// at each step, and collect all possible new states arrived by apply and not
// apply. This forms a tree, and we can use rule pruning or random pruning to
// reduce the number of sketches.
VLOG(6) << "Collect the states of all transfers within steps: " << steps;
while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) {
VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName();
std::list<SearchState> new_states;
int id = 0;
for (std::list<SearchState>::iterator iter = layer.begin(); iter != layer.end();) {
for (std::list<SearchState>::iterator iter = layer.begin();
iter != layer.end();) {
// Some rules will reduce the number of blocks, such as AutoInline,
// so we need to check whether the SearchState still has the block.
if (!(*iter)->ir_schedule.HasBlock(block_name)) {
......@@ -268,21 +308,26 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
continue;
}
auto type = rule->AnalyseApplyType(*iter, block_name);
VLOG(7) << "At SearchState " << ++id
<< ", apply type = " << static_cast<typename std::underlying_type<RuleApplyType>::type>(type);
VLOG(7)
<< "At SearchState " << ++id << ", apply type = "
<< static_cast<typename std::underlying_type<RuleApplyType>::type>(
type);
// if cannot apply the rule, skip it
if (type == RuleApplyType::kCannotApply) {
++iter;
continue;
}
// if can apply the rule, apply it and determine whether to prune the branch that do not apply
std::vector<SearchState> tmp_states = rule->ApplyOnBlock(*iter, block_name);
// if can apply the rule, apply it and determine whether to prune the
// branch that do not apply
std::vector<SearchState> tmp_states =
rule->ApplyOnBlock(*iter, block_name);
new_states.insert(new_states.end(), tmp_states.begin(), tmp_states.end());
bool need_prune = false;
if (prune_by_rule) {
need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules);
} else {
need_prune = (utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability);
need_prune =
(utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability);
}
if (need_prune) {
iter = layer.erase(iter);
......@@ -290,10 +335,12 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
++iter;
}
}
VLOG(7) << "apply on block: " << block_name << ", generate " << new_states.size() << " new states at step " << step;
VLOG(7) << "apply on block: " << block_name << ", generate "
<< new_states.size() << " new states at step " << step;
layer.splice(layer.end(), std::move(new_states));
}
VLOG(6) << "apply on block: " << block_name << ", generate " << layer.size() - 1 << " more states at all";
VLOG(6) << "apply on block: " << block_name << ", generate "
<< layer.size() - 1 << " more states at all";
return std::vector<SearchState>(layer.begin(), layer.end());
}
......
......@@ -40,24 +40,31 @@ namespace auto_schedule {
*/
class SearchSpace {
public:
SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed = -1);
SearchSpace(const TuneTask& tune_task,
utils::LinearRandomEngine::StateType rand_seed = -1);
// Sketch mutate, returns the mutated ModuleExpr and estimited cost
virtual SearchState GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model);
virtual SearchState GetScheduleMutate(const SearchState& state,
const ExprCostModel& cost_model);
/**
* \brief Generate sketch as initial population of evolutionary search.
* @param num The number of sketches to generate.
* @param strategy The strategy to generate sketchs,
* Current optional strategies are "rule_prune" or "random_prune" or "random".
* - "rule_prune": will use rules to prune and generate sketches as efficiently as possible.
* - "random_prune": will use the new interface ApplySketchRules() to simulate the random generation of sketches,
* and supports the function of a rule returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat this step several times,
* however, each rule can only be used on one SearchState at most once.
* Current optional strategies are "rule_prune" or "random_prune" or
* "random".
* - "rule_prune": will use rules to prune and generate sketches as
* efficiently as possible.
* - "random_prune": will use the new interface ApplySketchRules() to simulate
* the random generation of sketches, and supports the function of a rule
* returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat
* this step several times, however, each rule can only be used on one
* SearchState at most once.
* @return Generated sketchs.
*/
virtual std::vector<SearchState> GenerateSketches(int num, const std::string& strategy);
virtual std::vector<SearchState> GenerateSketches(
int num, const std::string& strategy);
private:
// TODO(zhhsplendid): mutate by manual schedule.
......@@ -69,20 +76,24 @@ class SearchSpace {
// Generate num sketchs, each with several rounds of SketchMutate
std::vector<SearchState> InitSketchWithRandomStrategy(int num);
// Generate sketch pruned randomly as initial population of evolutionary search
// Generate sketch pruned randomly as initial population of evolutionary
// search
std::vector<SearchState> InitSketchWithRandomPrunedStrategy();
// Generate sketch pruned by rules as initial population of evolutionary search
// Generate sketch pruned by rules as initial population of evolutionary
// search
std::vector<SearchState> InitSketchWithRulePrunedStrategy();
/**
* @brief Collect the new states that may be transferred to after applying several rules on a block from a certain
* state.
* @brief Collect the new states that may be transferred to after applying
* several rules on a block from a certain state.
* @param state Starting point of state transition.
* @param block_name Name of the block to apply the rules to.
* @param rule_sampler Sampler that samples the new rule to apply on the block.
* @param rule_sampler Sampler that samples the new rule to apply on the
* block.
* @param steps Number of steps to apply the rule.
* @param prune_by_rule If true, prune the state transition tree by rule, otherwise prune randomly.
* @param prune_by_rule If true, prune the state transition tree by rule,
* otherwise prune randomly.
* @param prune_probability Pruning probability of random pruning.
*/
std::vector<SearchState> ApplySketchRule(const SearchState& state,
......
......@@ -29,7 +29,9 @@
namespace cinn {
namespace auto_schedule {
SearchState::SearchState(ir::IRSchedule ir_sch, float cost, const std::vector<AutoGenRule*>& rules)
SearchState::SearchState(ir::IRSchedule ir_sch,
float cost,
const std::vector<AutoGenRule*>& rules)
: common::Shared<_SearchState_>(common::make_shared<_SearchState_>()) {
auto* state = get();
state->ir_schedule = std::move(ir_sch);
......@@ -37,13 +39,16 @@ SearchState::SearchState(ir::IRSchedule ir_sch, float cost, const std::vector<Au
state->predicted_cost = cost;
}
SearchState SearchState::Copy() const { return SearchState((*this)->ir_schedule, (*this)->predicted_cost, {}); }
SearchState SearchState::Copy() const {
return SearchState((*this)->ir_schedule, (*this)->predicted_cost, {});
}
std::string _SearchState_::DebugString() const {
const auto& exprs = ir_schedule.GetModule().GetExprs();
std::stringstream module_stream;
for (auto i = 0; i < exprs.size(); ++i) {
module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr";
module_stream << "Expr " << i << " {\n"
<< exprs.at(i) << "\n} // end Expr";
}
const char* fmt_str = R"ROC(
......@@ -55,8 +60,10 @@ ScheduleDesc {
} // end ScheduleDesc
predicted_cost: %f)ROC";
return utils::StringFormat(
fmt_str, module_stream.str().c_str(), ir_schedule.GetTraceDesc().DebugString().c_str(), predicted_cost);
return utils::StringFormat(fmt_str,
module_stream.str().c_str(),
ir_schedule.GetTraceDesc().DebugString().c_str(),
predicted_cost);
}
bool operator<(const SearchState& left, const SearchState& right) {
......@@ -119,7 +126,8 @@ size_t SearchStateHash::operator()(const SearchState& s) const {
return hash_key;
}
bool SearchStateEqual::operator()(const SearchState& lhs, const SearchState& rhs) const {
bool SearchStateEqual::operator()(const SearchState& lhs,
const SearchState& rhs) const {
const auto& lhs_exprs = lhs->ir_schedule.GetModule().GetExprs();
const auto& rhs_exprs = rhs->ir_schedule.GetModule().GetExprs();
// compare exprs size firstly
......@@ -127,20 +135,24 @@ bool SearchStateEqual::operator()(const SearchState& lhs, const SearchState& rhs
// compare every expr one by one with ir::IrEqualVisitor
for (int i = 0; i < lhs_exprs.size(); ++i) {
ir::IrEqualVisitor compartor(/*allow_name_suffix_diff=*/true); // ignore suffix difference in name
ir::IrEqualVisitor compartor(
/*allow_name_suffix_diff=*/true); // ignore suffix difference in name
if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
}
return true;
}
std::string JoinStatesDebugString(const std::string& title, const std::vector<SearchState>& states, bool verbose) {
std::string JoinStatesDebugString(const std::string& title,
const std::vector<SearchState>& states,
bool verbose) {
std::stringstream ss;
ss << title << " states size:" << states.size() << "\n";
SearchStateHash state_hasher;
for (size_t i = 0; i < states.size(); ++i) {
uint64_t hash_key = state_hasher(states[i]);
if (verbose) {
ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>" << states[i]->DebugString() << "\n<------";
ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>"
<< states[i]->DebugString() << "\n<------";
} else {
ss << "\tState-" << i << " hash:" << hash_key << "\n";
}
......
......@@ -35,7 +35,9 @@ class SearchState : public common::Shared<_SearchState_> {
public:
SearchState() = default;
// create a new SearchState
explicit SearchState(ir::IRSchedule ir_sch, float cost = NOT_INIT_COST, const std::vector<AutoGenRule*>& rules = {});
explicit SearchState(ir::IRSchedule ir_sch,
float cost = NOT_INIT_COST,
const std::vector<AutoGenRule*>& rules = {});
// Constant standing for a cost not being initialized
static constexpr float NOT_INIT_COST = std::numeric_limits<float>::max();
......@@ -62,12 +64,14 @@ struct _SearchState_ : public common::Object {
static constexpr char* __type_info__ = "auto_schedule_state";
};
// SearchStateHash hash functor that visits every AST node and combine their hash of node_type in dfs order
// SearchStateHash hash functor that visits every AST node and combine their
// hash of node_type in dfs order
struct SearchStateHash {
size_t operator()(const SearchState& s) const;
};
// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST struct and fields
// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST
// struct and fields
struct SearchStateEqual {
bool operator()(const SearchState& lhs, const SearchState& rhs) const;
};
......
......@@ -36,15 +36,34 @@ TEST(TestSearchState, SearchStateHash_Equal) {
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_1 =
lang::LowerVec("A_plus_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
auto a_plus_const_funcs_1 = lang::LowerVec("A_plus_const",
poly::CreateStages({A, B}),
{A, B},
{},
{},
nullptr,
target,
true);
cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_2 =
lang::LowerVec("A_plus_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
auto a_plus_const_funcs_2 = lang::LowerVec("A_plus_const",
poly::CreateStages({A, B}),
{A, B},
{},
{},
nullptr,
target,
true);
cinn::common::Context::Global().ResetNameId();
auto a_plus_b_funcs = lang::LowerVec("A_plus_B", poly::CreateStages({A, C}), {A, C}, {}, {}, nullptr, target, true);
auto a_plus_b_funcs = lang::LowerVec("A_plus_B",
poly::CreateStages({A, C}),
{A, C},
{},
{},
nullptr,
target,
true);
std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B)
{
......@@ -114,19 +133,25 @@ TEST(TestSearchState, SearchStateHash_Equal) {
})ROC";
ASSERT_EQ(a_plus_const_funcs_1.size(), 1);
EXPECT_EQ(a_plus_const_funcs_1_str, utils::GetStreamCnt(a_plus_const_funcs_1.front()));
EXPECT_EQ(a_plus_const_funcs_1_str,
utils::GetStreamCnt(a_plus_const_funcs_1.front()));
ASSERT_EQ(a_plus_const_funcs_2.size(), 1);
EXPECT_EQ(a_plus_const_funcs_2_str, utils::GetStreamCnt(a_plus_const_funcs_2.front()));
EXPECT_EQ(a_plus_const_funcs_2_str,
utils::GetStreamCnt(a_plus_const_funcs_2.front()));
ASSERT_EQ(a_plus_b_funcs.size(), 1);
EXPECT_EQ(a_plus_b_funcs_str, utils::GetStreamCnt(a_plus_b_funcs.front()));
SearchState a_plus_const_state1(ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_1.front()->body})));
SearchState a_plus_const_state2(ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_2.front()->body})));
SearchState a_plus_b_state(ir::IRSchedule(ir::ModuleExpr({a_plus_b_funcs.front()->body})));
SearchState a_plus_const_state1(
ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_1.front()->body})));
SearchState a_plus_const_state2(
ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_2.front()->body})));
SearchState a_plus_b_state(
ir::IRSchedule(ir::ModuleExpr({a_plus_b_funcs.front()->body})));
SearchStateHash hash_functor;
SearchStateEqual equal_functor;
ASSERT_EQ(hash_functor(a_plus_const_state1), hash_functor(a_plus_const_state2));
ASSERT_EQ(hash_functor(a_plus_const_state1),
hash_functor(a_plus_const_state2));
ASSERT_TRUE(equal_functor(a_plus_const_state1, a_plus_const_state2));
ASSERT_NE(hash_functor(a_plus_const_state1), hash_functor(a_plus_b_state));
ASSERT_FALSE(equal_functor(a_plus_const_state1, a_plus_b_state));
......
......@@ -41,7 +41,8 @@ DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn {
namespace auto_schedule {
EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task,
EvolutionarySearch::EvolutionarySearch(
const TuneTask& tune_task,
const ExprCostModel& cost_model,
Database* database,
utils::LinearRandomEngine::StateType rand_seed,
......@@ -51,7 +52,8 @@ EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task,
database_(database),
rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)),
mutators_(mutate_rules) {
search_space_ = std::make_unique<SearchSpace>(tune_task, utils::ForkRandomState(&rand_seed_));
search_space_ = std::make_unique<SearchSpace>(
tune_task, utils::ForkRandomState(&rand_seed_));
if (mutators_.empty()) {
mutators_.push_back(std::make_tuple("mutate_tile_size", 1.0));
}
......@@ -59,7 +61,8 @@ EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task,
for (const auto& mutator : mutators_) {
if (std::get<1>(mutator) > 0) {
accum_weight += std::get<1>(mutator);
weighted_mutators_.insert(std::make_pair(accum_weight, MutateRule::Make(std::get<0>(mutator))));
weighted_mutators_.insert(
std::make_pair(accum_weight, MutateRule::Make(std::get<0>(mutator))));
}
}
......@@ -72,46 +75,66 @@ SearchState EvolutionarySearch::SearchModuleExpr(const TuningOptions& options) {
return SearchModuleExprBests(options)[0];
}
std::vector<SearchState> EvolutionarySearch::SearchModuleExprBests(const TuningOptions& options) {
VLOG(4) << "start SearchModuleExprBests with initial statistics: visited_candidates size="
std::vector<SearchState> EvolutionarySearch::SearchModuleExprBests(
const TuningOptions& options) {
VLOG(4) << "start SearchModuleExprBests with initial statistics: "
"visited_candidates size="
<< visited_candidates_.size();
std::vector<SearchState> init_population;
std::vector<SearchState> topk_from_database = GetTopKCandidatesFromDatabase(options.evolution_pick_database_topk);
std::vector<SearchState> topk_from_database =
GetTopKCandidatesFromDatabase(options.evolution_pick_database_topk);
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::GetTopKCandidatesFromDatabase", topk_from_database, /*verbose=*/VLOG_IS_ON(5));
int init_num = options.evolution_init_population_num - topk_from_database.size();
"EvolutionarySearch::GetTopKCandidatesFromDatabase",
topk_from_database,
/*verbose=*/VLOG_IS_ON(5));
int init_num =
options.evolution_init_population_num - topk_from_database.size();
std::vector<SearchState> init_sketch = InitSketch(init_num, "rule_prune");
VLOG(4) << JoinStatesDebugString("EvolutionarySearch::InitSketch", init_sketch, /*verbose=*/VLOG_IS_ON(5));
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::InitSketch", init_sketch, /*verbose=*/VLOG_IS_ON(5));
init_population.insert(init_population.end(), topk_from_database.begin(), topk_from_database.end());
init_population.insert(init_population.end(), init_sketch.begin(), init_sketch.end());
init_population.insert(init_population.end(),
topk_from_database.begin(),
topk_from_database.end());
init_population.insert(
init_population.end(), init_sketch.begin(), init_sketch.end());
std::vector<SearchState> picked_bests =
Evolve(init_population, options.evolution_cross_over_num, options.num_samples_per_iteration);
VLOG(4) << JoinStatesDebugString("EvolutionarySearch::Evolve", picked_bests, /*verbose=*/VLOG_IS_ON(5));
Evolve(init_population,
options.evolution_cross_over_num,
options.num_samples_per_iteration);
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve", picked_bests, /*verbose=*/VLOG_IS_ON(5));
return picked_bests;
}
std::vector<SearchState> EvolutionarySearch::SearchModuleExprEpsGreedy(const TuningOptions& options) {
std::vector<SearchState> EvolutionarySearch::SearchModuleExprEpsGreedy(
const TuningOptions& options) {
std::vector<SearchState> picked_bests = SearchModuleExprBests(options);
int random_num = options.evolution_init_population_num - options.evolution_pick_database_topk;
auto results = PickNextGenerationEpsGreedy(picked_bests,
int random_num = options.evolution_init_population_num -
options.evolution_pick_database_topk;
auto results =
PickNextGenerationEpsGreedy(picked_bests,
InitSketch(random_num, "random_prune"),
options.num_samples_per_iteration,
options.evolution_eps_greedy);
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy", results, /*verbose=*/VLOG_IS_ON(5));
"EvolutionarySearch::PickNextGenerationEpsGreedy",
results,
/*verbose=*/VLOG_IS_ON(5));
return results;
}
std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(int topk) {
std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
int topk) {
std::vector<SearchState> results;
const auto& task_key = tune_task_.serialized_key;
auto records = database_->GetTopK(task_key, topk);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto&& record : records) {
ir::IRSchedule ir_sch(optim::IRCopy(task_registry->Get(task_key)->module_expr),
ir::IRSchedule ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(&rand_seed_));
ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch);
results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost));
......@@ -119,7 +142,8 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(int t
return results;
}
void ApplyPostScheduleRules(ir::IRSchedule* schedule,
void ApplyPostScheduleRules(
ir::IRSchedule* schedule,
const std::vector<std::unique_ptr<PostScheduleRule>>& post_schedule_rules) {
schedule->TagPostSchedule();
for (const auto& post_rule : post_schedule_rules) {
......@@ -127,25 +151,33 @@ void ApplyPostScheduleRules(ir::IRSchedule* schedule,
}
}
std::vector<SearchState> EvolutionarySearch::InitSketch(int num, const std::string& strategy) {
std::vector<SearchState> EvolutionarySearch::InitSketch(
int num, const std::string& strategy) {
VLOG(4) << "InitSketch with num:" << num << ", strategy: " << strategy;
std::vector<SearchState> states = search_space_->GenerateSketches(num, strategy);
std::vector<SearchState> states =
search_space_->GenerateSketches(num, strategy);
auto post_schedule_fn = [this, &states](int index) {
ApplyPostScheduleRules(&states[index]->ir_schedule, post_schedule_rules_);
};
utils::parallel_run(post_schedule_fn, utils::SequenceDispatcher(0, states.size()), states.size());
utils::parallel_run(post_schedule_fn,
utils::SequenceDispatcher(0, states.size()),
states.size());
return states;
}
SearchState EvolutionarySearch::CrossOver(const SearchState& state1, const SearchState& state2) {
SearchState EvolutionarySearch::CrossOver(const SearchState& state1,
const SearchState& state2) {
// TODO(CtfGo): tracing CrossOver with IRSchedule
std::vector<ir::Expr> cross_over_exprs;
std::vector<ir::Expr> father_exprs = state1->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> mother_exprs = state2->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> father_exprs =
state1->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> mother_exprs =
state2->ir_schedule.GetModule().GetExprs();
CHECK_EQ(father_exprs.size(), mother_exprs.size())
<< "CrossOver ModuleExpr in EvolutionarySearch must have same number of AST";
<< "CrossOver ModuleExpr in EvolutionarySearch must have same number of "
"AST";
for (size_t i = 0; i < father_exprs.size(); ++i) {
if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) {
......@@ -154,16 +186,22 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1, const Searc
cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i]));
}
}
auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs), utils::ForkRandomState(&rand_seed_)));
auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs),
utils::ForkRandomState(&rand_seed_)));
if (FLAGS_auto_schedule_use_cost_model) {
res->predicted_cost = cost_model_.Predict(res->ir_schedule.GetModule(), tune_task_.target);
res->predicted_cost =
cost_model_.Predict(res->ir_schedule.GetModule(), tune_task_.target);
}
VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver", {state1, state2, res}, /*verbose=*/VLOG_IS_ON(6));
VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver",
{state1, state2, res},
/*verbose=*/VLOG_IS_ON(6));
return res;
}
SearchState EvolutionarySearch::Mutate(const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed) {
CHECK_GT(weighted_mutators_.size(), 0) << "There is no mutate rule can be applied.";
SearchState EvolutionarySearch::Mutate(
const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed) {
CHECK_GT(weighted_mutators_.size(), 0)
<< "There is no mutate rule can be applied.";
double accu_weight = (weighted_mutators_.rbegin())->first;
CHECK_GT(accu_weight, 0) << "The accumulate weight must be greater than 0.";
// sample a mutate rule
......@@ -174,24 +212,31 @@ SearchState EvolutionarySearch::Mutate(const SearchState& state, utils::LinearRa
// apply mutation on the trace of SearchState
auto trace = state->ir_schedule.GetTraceDesc();
auto new_trace = mutator->Apply(trace, rand_seed);
// replay the mutated trace on original ModuleExpr to generate a new ir_schedule
// replay the mutated trace on original ModuleExpr to generate a new
// ir_schedule
const auto& task_key = tune_task_.serialized_key;
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
ir::IRSchedule new_ir_sch(optim::IRCopy(task_registry->Get(task_key)->module_expr),
ir::IRSchedule new_ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(rand_seed));
new_trace.Replay(&new_ir_sch, true);
ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);
auto res = SearchState(std::move(new_ir_sch));
VLOG(5) << JoinStatesDebugString("EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6));
VLOG(5) << JoinStatesDebugString(
"EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6));
return res;
}
std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchState>& population,
std::vector<SearchState> EvolutionarySearch::Evolve(
const std::vector<SearchState>& population,
int cross_over_num,
int ret_num) {
VLOG(4) << utils::StringFormat(
"Evolve with population size=%lu,cross_over_num:%lu,ret_num:%lu", population.size(), cross_over_num, ret_num);
"Evolve with population size=%lu,cross_over_num:%lu,ret_num:%lu",
population.size(),
cross_over_num,
ret_num);
int generation_num = population.size();
if (generation_num == 0) {
return std::vector<SearchState>();
......@@ -199,40 +244,56 @@ std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchStat
// init evolution
std::vector<SearchState> evolution(population);
for (SearchState& search_state : evolution) {
if (search_state->predicted_cost == SearchState::NOT_INIT_COST && FLAGS_auto_schedule_use_cost_model) {
search_state->predicted_cost = cost_model_.Predict(search_state->ir_schedule.GetModule(), tune_task_.target);
if (search_state->predicted_cost == SearchState::NOT_INIT_COST &&
FLAGS_auto_schedule_use_cost_model) {
search_state->predicted_cost = cost_model_.Predict(
search_state->ir_schedule.GetModule(), tune_task_.target);
}
}
VLOG(4) << JoinStatesDebugString("EvolutionarySearch::Evolve: Init evolution:", evolution, /*verbose=*/VLOG_IS_ON(5));
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: Init evolution:",
evolution,
/*verbose=*/VLOG_IS_ON(5));
// cross over
for (int i = 0; i < cross_over_num; ++i) {
int first_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_);
int second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_);
int first_rand_idx =
utils::SampleUniformInt(0, generation_num, &rand_seed_);
int second_rand_idx =
utils::SampleUniformInt(0, generation_num, &rand_seed_);
while (first_rand_idx == second_rand_idx) {
second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_);
}
evolution.push_back(CrossOver(population[first_rand_idx], population[second_rand_idx]));
evolution.push_back(
CrossOver(population[first_rand_idx], population[second_rand_idx]));
}
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: after CrossOver evolution:", evolution, /*verbose=*/VLOG_IS_ON(5));
"EvolutionarySearch::Evolve: after CrossOver evolution:",
evolution,
/*verbose=*/VLOG_IS_ON(5));
// mutate
std::vector<SearchState> mutated_individuals(evolution.size());
std::vector<utils::LinearRandomEngine::StateType> rand_seeds(evolution.size());
std::vector<utils::LinearRandomEngine::StateType> rand_seeds(
evolution.size());
for (int i = 0; i < rand_seeds.size(); ++i) {
rand_seeds[i] = utils::ForkRandomState(&rand_seed_);
}
auto mutate_fn = [this, &evolution, &mutated_individuals, &rand_seeds](int index) {
auto mutate_fn = [this, &evolution, &mutated_individuals, &rand_seeds](
int index) {
mutated_individuals[index] = Mutate(evolution[index], &rand_seeds[index]);
};
utils::parallel_run(mutate_fn, utils::SequenceDispatcher(0, evolution.size()), evolution.size());
utils::parallel_run(mutate_fn,
utils::SequenceDispatcher(0, evolution.size()),
evolution.size());
if (FLAGS_auto_schedule_use_cost_model) {
for (size_t i = 0; i < mutated_individuals.size(); ++i) {
mutated_individuals[i]->predicted_cost =
cost_model_.Predict(mutated_individuals[i]->ir_schedule.GetModule(), tune_task_.target);
mutated_individuals[i]->predicted_cost = cost_model_.Predict(
mutated_individuals[i]->ir_schedule.GetModule(), tune_task_.target);
}
}
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: mutated individuals:", mutated_individuals, /*verbose=*/VLOG_IS_ON(5));
"EvolutionarySearch::Evolve: mutated individuals:",
mutated_individuals,
/*verbose=*/VLOG_IS_ON(5));
// select top ret_num with predicted cost
utils::SizedMultiSet<SearchState> evolution_with_cost(ret_num);
for (size_t i = 0; i < evolution.size(); ++i) {
......@@ -241,14 +302,18 @@ std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchStat
for (size_t i = 0; i < mutated_individuals.size(); ++i) {
evolution_with_cost.Push(mutated_individuals[i]);
}
auto selected_individuals = evolution_with_cost.ReturnAsContainer<std::vector<SearchState>>();
auto selected_individuals =
evolution_with_cost.ReturnAsContainer<std::vector<SearchState>>();
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: selected individuals:", selected_individuals, /*verbose=*/VLOG_IS_ON(5));
"EvolutionarySearch::Evolve: selected individuals:",
selected_individuals,
/*verbose=*/VLOG_IS_ON(5));
return selected_individuals;
}
std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(const std::vector<SearchState>& picked_bests,
std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(
const std::vector<SearchState>& picked_bests,
const std::vector<SearchState>& random_init,
int num,
float eps_greedy) {
......@@ -276,18 +341,23 @@ std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(const s
if (!visited_candidates_.count(selected)) { // deduplicate
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy-Selected", {selected}, /*verbose=*/VLOG_IS_ON(5));
"EvolutionarySearch::PickNextGenerationEpsGreedy-Selected",
{selected},
/*verbose=*/VLOG_IS_ON(5));
visited_candidates_.insert(selected);
result.push_back(selected);
} else {
++deduplicated_cnt;
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated", {selected}, /*verbose=*/VLOG_IS_ON(5));
"EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated",
{selected},
/*verbose=*/VLOG_IS_ON(5));
}
}
VLOG(4) << utils::StringFormat(
"PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init size=%lu,num=%d,"
"PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init "
"size=%lu,num=%d,"
"eps_greedy=%f,deduplicated_cnt=%d,result size=%lu",
picked_bests.size(),
random_init.size(),
......
......@@ -41,7 +41,8 @@ class EvolutionarySearch {
* @param tune_task: the TuneTask this class works on. This class doesn't
* take ownership of the pointer.
*/
EvolutionarySearch(const TuneTask& tune_task,
EvolutionarySearch(
const TuneTask& tune_task,
const ExprCostModel& cost_model,
Database* database,
utils::LinearRandomEngine::StateType rand_seed = -1,
......@@ -55,14 +56,16 @@ class EvolutionarySearch {
/**
* Run the evolutionary search for one iteration.
*
* @return SearchState containing the best ir::ModuleExpr searched in this iteration
* @return SearchState containing the best ir::ModuleExpr searched in this
* iteration
*/
SearchState SearchModuleExpr(const TuningOptions& options);
/**
* Run the evolutionary search for one iteration.
*
* @return SearchState(s) containing best ir::ModuleExpr(s) searched in this iteration
* @return SearchState(s) containing best ir::ModuleExpr(s) searched in this
* iteration
*/
std::vector<SearchState> SearchModuleExprBests(const TuningOptions& options);
......@@ -77,7 +80,8 @@ class EvolutionarySearch {
* "eps * total_return_size" random samples and
* "(1 - eps) * total_return_size" best searched samples.
*/
std::vector<SearchState> SearchModuleExprEpsGreedy(const TuningOptions& options);
std::vector<SearchState> SearchModuleExprEpsGreedy(
const TuningOptions& options);
#ifdef CINN_WITH_TEST
/**
......@@ -87,13 +91,23 @@ class EvolutionarySearch {
* @param search_space: the mock search space, note that EvolutionarySearch
* takes the ownership.
*/
void SetSearchSpace(SearchSpace* search_space) { search_space_.reset(search_space); }
void SetSearchSpace(SearchSpace* search_space) {
search_space_.reset(search_space);
}
// Method only be called during testing, it is a wrapper of private method InitSketch().
std::vector<SearchState> TestInitSketch(int num, const std::string& strategy) { return InitSketch(num, strategy); }
// Method only be called during testing, it is a wrapper of private method
// InitSketch().
std::vector<SearchState> TestInitSketch(int num,
const std::string& strategy) {
return InitSketch(num, strategy);
}
// Method only be called during testing, it is a wrapper of private method Evolve().
std::vector<SearchState> TestEvolve(const std::vector<SearchState>& population, int cross_over_num, int ret_num) {
// Method only be called during testing, it is a wrapper of private method
// Evolve().
std::vector<SearchState> TestEvolve(
const std::vector<SearchState>& population,
int cross_over_num,
int ret_num) {
return Evolve(population, cross_over_num, ret_num);
}
#endif
......@@ -105,23 +119,31 @@ class EvolutionarySearch {
* \brief Generate sketch as initial population of evolutionary search.
* @param num The number of sketches to generate.
* @param strategy The strategy to generate sketches,
* Current optional strategies are "rule_prune" or "random_prune" or "random".
* - "rule_prune": will use rules to prune and generate sketches as efficiently as possible.
* - "random_prune": will use the new interface ApplySketchRules() to simulate the random generation of sketches,
* and supports the function of a rule returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat this step several times,
* however, each rule can only be used on one SearchState at most once.
* Current optional strategies are "rule_prune" or "random_prune" or
* "random".
* - "rule_prune": will use rules to prune and generate sketches as
* efficiently as possible.
* - "random_prune": will use the new interface ApplySketchRules() to simulate
* the random generation of sketches, and supports the function of a rule
* returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat
* this step several times, however, each rule can only be used on one
* SearchState at most once.
* @return Generated sketches.
*/
std::vector<SearchState> InitSketch(int num, const std::string& strategy);
SearchState Mutate(const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed);
SearchState Mutate(const SearchState& state,
utils::LinearRandomEngine::StateType* rand_seed);
SearchState CrossOver(const SearchState& state1, const SearchState& state2);
std::vector<SearchState> Evolve(const std::vector<SearchState>& population, int cross_over_num, int ret_num);
std::vector<SearchState> Evolve(const std::vector<SearchState>& population,
int cross_over_num,
int ret_num);
std::vector<SearchState> PickNextGenerationEpsGreedy(const std::vector<SearchState>& population,
std::vector<SearchState> PickNextGenerationEpsGreedy(
const std::vector<SearchState>& population,
const std::vector<SearchState>& random_init,
int num,
float eps_greedy);
......@@ -132,7 +154,8 @@ class EvolutionarySearch {
const ExprCostModel& cost_model_; // not owned
Database* database_; // not owned
// used to duplicate states with the same structural IR
std::unordered_set<SearchState, SearchStateHash, SearchStateEqual> visited_candidates_;
std::unordered_set<SearchState, SearchStateHash, SearchStateEqual>
visited_candidates_;
// mutate rule names and their weights
std::vector<std::tuple<std::string, double>> mutators_;
// mutate rules, the key is the accumulate weight of each mutate rule
......
......@@ -34,17 +34,23 @@
namespace cinn {
namespace auto_schedule {
std::vector<TuneTask> CreateTasks(const frontend::Program& program, const Target& target) {
std::vector<TuneTask> CreateTasks(const frontend::Program& program,
const Target& target) {
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
TaskCreator task_creator;
auto tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target);
const auto& dtype_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto i = 0; i < tasks.size(); ++i) {
tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get());
task_registry->Regist(tasks[i].serialized_key, ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs()));
task_registry->Regist(tasks[i].serialized_key,
ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs()));
}
return tasks;
}
......@@ -64,7 +70,8 @@ class MockSearchSpace : public SearchSpace {
int GetModuleExprSize() const { return module_expr_size_; }
std::vector<SearchState> GenerateSketches(int num, const std::string& strategy) override {
std::vector<SearchState> GenerateSketches(
int num, const std::string& strategy) override {
std::vector<SearchState> ret;
for (int i = 0; i < num; ++i) {
std::vector<ir::Expr> exprs;
......@@ -83,7 +90,8 @@ class MockSearchSpace : public SearchSpace {
};
class MockCostModel : public ExprCostModel {
float Predict(const ir::ModuleExpr& sample, const common::Target& target) const override {
float Predict(const ir::ModuleExpr& sample,
const common::Target& target) const override {
float cost = 0.0f;
std::vector<ir::Expr> exprs = sample.GetExprs();
for (const ir::Expr& expr : exprs) {
......@@ -100,7 +108,8 @@ TEST(EvolutionarySearch, GetOneBest) {
mock_tune_task.serialized_key = "mock_task";
mock_tune_task.target = common::DefaultTarget();
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
task_registry->Regist(mock_tune_task.serialized_key, ir::ModuleExpr({ir::Expr(0)}));
task_registry->Regist(mock_tune_task.serialized_key,
ir::ModuleExpr({ir::Expr(0)}));
MockCostModel cost_model;
TuningOptions options;
Database db(2);
......@@ -122,7 +131,8 @@ TEST(EvolutionarySearch, GetEpsGreedy) {
mock_tune_task.serialized_key = "mock_task";
mock_tune_task.target = common::DefaultTarget();
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
task_registry->Regist(mock_tune_task.serialized_key, ir::ModuleExpr({ir::Expr(0)}));
task_registry->Regist(mock_tune_task.serialized_key,
ir::ModuleExpr({ir::Expr(0)}));
ExprCostModel cost_model;
TuningOptions options;
Database db(2);
......@@ -131,10 +141,12 @@ TEST(EvolutionarySearch, GetEpsGreedy) {
MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task);
// Ownership is transferred so don't delete mock_search_space
evolutionary_search.SetSearchSpace(mock_search_space);
std::vector<SearchState> search_states = evolutionary_search.SearchModuleExprEpsGreedy(options);
std::vector<SearchState> search_states =
evolutionary_search.SearchModuleExprEpsGreedy(options);
EXPECT_GE(search_states.size(), 1UL);
size_t expr_size = static_cast<size_t>(mock_search_space->GetModuleExprSize());
size_t expr_size =
static_cast<size_t>(mock_search_space->GetModuleExprSize());
for (const SearchState& state : search_states) {
EXPECT_EQ(state->ir_schedule.GetModule().GetExprs().size(), expr_size);
}
......@@ -142,7 +154,9 @@ TEST(EvolutionarySearch, GetEpsGreedy) {
TEST(EvolutionarySearch, Evolve) {
auto target = common::DefaultNVGPUTarget();
auto tasks = CreateTasks(tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}), target);
auto tasks = CreateTasks(
tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}),
target);
CHECK_EQ(tasks.size(), 1);
ExprCostModel cost_model;
std::vector<const ir::ModuleExpr*> cost_model_samples(1);
......@@ -161,7 +175,8 @@ TEST(EvolutionarySearch, Evolve) {
EvolutionarySearch evolutionary_search(tasks[0], cost_model, &db);
int num_population = 10;
std::vector<SearchState> init_sketch = evolutionary_search.TestInitSketch(num_population, "rule_prune");
std::vector<SearchState> init_sketch =
evolutionary_search.TestInitSketch(num_population, "rule_prune");
for (int i = 0; i < num_population; ++i) {
ir::ModuleExpr me(init_sketch[i]->ir_schedule.GetModule());
cost_model_samples[0] = &me;
......@@ -172,10 +187,12 @@ TEST(EvolutionarySearch, Evolve) {
for (auto s : init_sketch) {
VLOG(6) << "cost = " << s->predicted_cost;
}
std::vector<SearchState>*population_pre_ptr = &init_sketch, *population_next_ptr;
std::vector<SearchState>*population_pre_ptr = &init_sketch,
*population_next_ptr;
std::vector<SearchState> population;
for (int i = 0; i < 10; ++i) {
population = evolutionary_search.TestEvolve(*population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10);
population = evolutionary_search.TestEvolve(
*population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10);
population_next_ptr = &population;
VLOG(6) << "population[" << i + 1 << "] costs:";
double total_cost_pre = 0.0, total_cost_next = 0.0;
......
......@@ -34,11 +34,14 @@ class MutateRule {
* @param rand_seed The random seed for mutation.
* @return The mutated trace.
*/
virtual ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) = 0;
virtual ir::ScheduleDesc Apply(
const ir::ScheduleDesc& trace,
utils::LinearRandomEngine::StateType* rand_seed) = 0;
/**
* @brief Create a MutateRule with name.
* @param name The name of mutate rule, consisting of lowercase letters and underscores
* @param name The name of mutate rule, consisting of lowercase letters and
* underscores
* @return The created MutateRule.
*/
static std::unique_ptr<MutateRule> Make(const std::string& name);
......
......@@ -44,8 +44,11 @@ std::vector<SampledTile> FindSampledTiles(const ScheduleDesc& trace) {
break;
}
if (step.type == "SamplePerfectTile") {
std::vector<int> tile_factors = absl::get<std::vector<int>>(step.attrs.at("decision"));
CHECK(tile_factors.size() >= 2) << "factors size must be greater equal than 2, which is " << tile_factors.size();
std::vector<int> tile_factors =
absl::get<std::vector<int>>(step.attrs.at("decision"));
CHECK(tile_factors.size() >= 2)
<< "factors size must be greater equal than 2, which is "
<< tile_factors.size();
tiles.push_back(std::make_tuple(step, tile_factors, step_idx));
}
++step_idx;
......@@ -89,10 +92,13 @@ ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace,
// Step 2. Choose the divisor for mutate.
int divisor;
if (loop_y == split_size - 1) {
int max_innermost_factor = absl::get<int>(step.attrs.at("max_innermost_factor"));
int max_innermost_factor =
absl::get<int>(step.attrs.at("max_innermost_factor"));
int max_optional_factor_idx = optional_factors.size() - 1;
for (; max_optional_factor_idx > 0; --max_optional_factor_idx) {
if (optional_factors.at(max_optional_factor_idx) * tile_factors.at(loop_y) <= max_innermost_factor) {
if (optional_factors.at(max_optional_factor_idx) *
tile_factors.at(loop_y) <=
max_innermost_factor) {
break;
}
}
......@@ -103,27 +109,32 @@ ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace,
}
continue;
}
divisor = optional_factors.at(utils::SampleUniformInt(1, max_optional_factor_idx + 1, rand_seed));
divisor = optional_factors.at(
utils::SampleUniformInt(1, max_optional_factor_idx + 1, rand_seed));
} else {
divisor = optional_factors.at(utils::SampleUniformInt(1, optional_factors.size(), rand_seed));
divisor = optional_factors.at(
utils::SampleUniformInt(1, optional_factors.size(), rand_seed));
}
// Step 3. Determine the new tile value
VLOG(6) << "DoMutateTileSize: divisor = " << divisor << ", before mutate: \n"
<< "factors[" << loop_x << "] = " << tile_factors[loop_x] << ", factors[" << loop_y
<< "] = " << tile_factors[loop_y];
VLOG(6) << "DoMutateTileSize: divisor = " << divisor
<< ", before mutate: \n"
<< "factors[" << loop_x << "] = " << tile_factors[loop_x]
<< ", factors[" << loop_y << "] = " << tile_factors[loop_y];
tile_factors[loop_x] /= divisor;
tile_factors[loop_y] *= divisor;
VLOG(6) << "after mutate: \n"
<< "factors[" << loop_x << "] = " << tile_factors[loop_x] << ", factors[" << loop_y
<< "] = " << tile_factors[loop_y];
<< "factors[" << loop_x << "] = " << tile_factors[loop_x]
<< ", factors[" << loop_y << "] = " << tile_factors[loop_y];
// Step 4. Create a new step with new tile values and return the new trace
int step_idx = std::get<2>(tile);
return trace.ForkAndUpdate(step_idx, tile_factors, true);
}
}
ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, LinearRandomEngine::StateType* rand_seed) {
VLOG(6) << "Start applying MutateTileSize, old trace: \n" << trace.DebugString();
ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace,
LinearRandomEngine::StateType* rand_seed) {
VLOG(6) << "Start applying MutateTileSize, old trace: \n"
<< trace.DebugString();
std::vector<ScheduleDesc::Step> sample_tile_steps;
std::vector<std::vector<int>> sample_tile_data;
......@@ -132,9 +143,12 @@ ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, LinearRandomEngine
VLOG(6) << "MutateTileSize failed, try other mutate rules.";
return trace;
}
int sample_step_idx = utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed);
auto new_trace = DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed);
VLOG(6) << "End applying MutateTileSize, new trace: \n" << new_trace.DebugString();
int sample_step_idx =
utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed);
auto new_trace =
DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed);
VLOG(6) << "End applying MutateTileSize, new trace: \n"
<< new_trace.DebugString();
return new_trace;
}
......
......@@ -20,13 +20,16 @@ namespace cinn {
namespace auto_schedule {
/**
* The rule to mutate tile size, witch will modify the factors of the Split primitive.
* The rule to mutate tile size, witch will modify the factors of the Split
* primitive.
*/
class MutateTileSize : public MutateRule {
public:
MutateTileSize() = default;
ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) override;
ir::ScheduleDesc Apply(
const ir::ScheduleDesc& trace,
utils::LinearRandomEngine::StateType* rand_seed) override;
};
} // namespace auto_schedule
......
......@@ -42,17 +42,27 @@ TEST(MutateTileSize, Basic) {
Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
{M, N},
[&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({A, B, C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMutateTileSize_Basic", stages, {A, B, C}, {}, {}, nullptr, target, true);
lang::LowerVec("TestMutateTileSize_Basic",
stages,
{A, B, C},
{},
{},
nullptr,
target,
true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Original Expr: ";
VLOG(6) << ast_expr;
ir::ModuleExpr module_expr({ast_expr});
// We need to fix the seed as a constant to ensure that the result can be repeated.
// We need to fix the seed as a constant to ensure that the result can be
// repeated.
utils::LinearRandomEngine::StateType rand_seed = 123;
ir::IRSchedule ir_schedule(module_expr, rand_seed);
ir::IRSchedule new_ir_schedule(ir_schedule);
......@@ -64,10 +74,13 @@ TEST(MutateTileSize, Basic) {
// apply mutate
MutateTileSize mutator;
ir::ScheduleDesc sch_desc = mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
ir::ScheduleDesc sch_desc =
mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
sch_desc.Replay(&new_ir_schedule, true);
VLOG(6) << "Expr before mutate tile size: \n" << ir_schedule.GetModule().GetExprs()[0];
VLOG(6) << "Expr after mutate tile size: \n" << new_ir_schedule.GetModule().GetExprs()[0];
VLOG(6) << "Expr before mutate tile size: \n"
<< ir_schedule.GetModule().GetExprs()[0];
VLOG(6) << "Expr after mutate tile size: \n"
<< new_ir_schedule.GetModule().GetExprs()[0];
std::string target_new_ir = R"ROC({
ScheduleBlock(root)
......@@ -111,7 +124,8 @@ TEST(MutateTileSize, Basic) {
sch_desc = mutator.Apply(sch_desc, &rand_seed);
for (auto&& step : sch_desc.Steps()) {
if (step.type == "SamplePerfectTile") {
std::vector<int> tile_factors = absl::get<std::vector<int>>(step.attrs.at("decision"));
std::vector<int> tile_factors =
absl::get<std::vector<int>>(step.attrs.at("decision"));
ASSERT_EQ(tile_factors.size(), last_tile_factors.size());
ASSERT_NE(tile_factors[0], last_tile_factors[0]);
ASSERT_NE(tile_factors[1], last_tile_factors[1]);
......
......@@ -36,7 +36,8 @@ using ::cinn::hlir::framework::NodeData;
std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) {
std::vector<TuneTask> ret_tasks;
const std::vector<std::shared_ptr<Graph::Group>>* groups = &graph->fusion_groups;
const std::vector<std::shared_ptr<Graph::Group>>* groups =
&graph->fusion_groups;
std::vector<std::shared_ptr<Graph::Group>> non_fused_groups;
// The input graph doesn't run Op Fusion
if (graph->fusion_groups.empty()) {
......
......@@ -45,7 +45,8 @@ class TaskOptimizer {
std::string from;
double cost;
FunctionGroup functions;
Result(const std::string& from_type) : from(from_type), cost(std::numeric_limits<double>::max()) {}
Result(const std::string& from_type)
: from(from_type), cost(std::numeric_limits<double>::max()) {}
};
Result OptimizeByManual(bool need_measure);
......@@ -53,7 +54,9 @@ class TaskOptimizer {
Result OptimizeByEvolution(const TuningOptions& options);
// call search candidates once by EvolutionarySearch and prune invalid ones
std::vector<SearchState> SearchOneRound(const TuningOptions& options, std::vector<MeasureInput>* measure_candidates);
std::vector<SearchState> SearchOneRound(
const TuningOptions& options,
std::vector<MeasureInput>* measure_candidates);
private:
// the max retry times if continuously get empty result
......
......@@ -31,7 +31,8 @@ struct InitialTaskInfo {
std::string task_key;
ir::ModuleExpr module_expr;
InitialTaskInfo(const std::string& task_key, const ir::ModuleExpr& module_expr)
InitialTaskInfo(const std::string& task_key,
const ir::ModuleExpr& module_expr)
: task_key(task_key), module_expr(module_expr) {}
};
......@@ -45,19 +46,25 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> {
// Get the initial ModuleExpr of a task.
inline const InitialTaskInfo* Get(const std::string& task_key) {
const InitialTaskInfo* task_info = Registry<InitialTaskInfo>::Find(task_key);
CHECK(task_info) << "InitialTaskInfo [" << task_key << "] is not registered";
const InitialTaskInfo* task_info =
Registry<InitialTaskInfo>::Find(task_key);
CHECK(task_info) << "InitialTaskInfo [" << task_key
<< "] is not registered";
return task_info;
}
// Check if the task info with task_key exists;
inline const bool Has(const std::string& task_key) { return nullptr != Registry<InitialTaskInfo>::Find(task_key); }
inline const bool Has(const std::string& task_key) {
return nullptr != Registry<InitialTaskInfo>::Find(task_key);
}
// Regist the initial ModuleExpr of a task into the map
inline void Regist(const std::string& task_key, const ir::ModuleExpr& module_expr) {
inline void Regist(const std::string& task_key,
const ir::ModuleExpr& module_expr) {
std::lock_guard<std::mutex> guard(registering_mutex);
if (fmap_.count(task_key) == 0) {
InitialTaskInfo* task_info = new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
InitialTaskInfo* task_info =
new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
__REGISTER__(task_key, task_info);
}
}
......@@ -67,7 +74,8 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> {
CINN_DISALLOW_COPY_AND_ASSIGN(InitialTaskRegistry);
// Regist the initial ModuleExpr of a task.
inline InitialTaskInfo* __REGISTER__(const std::string& task_key, InitialTaskInfo* task_info) {
inline InitialTaskInfo* __REGISTER__(const std::string& task_key,
InitialTaskInfo* task_info) {
fmap_[task_key] = task_info;
const_list_.push_back(task_info);
entry_list_.push_back(task_info);
......
......@@ -27,7 +27,9 @@ int EfficiencyPriority::NextTaskId() {
return -1;
}
bool EfficiencyPriority::IsTaskToTune(const TuneTask* task) { return config_.minimum_gain_threshold > 0.0; }
bool EfficiencyPriority::IsTaskToTune(const TuneTask* task) {
return config_.minimum_gain_threshold > 0.0;
}
} // namespace auto_schedule
} // namespace cinn
......@@ -25,7 +25,8 @@ namespace auto_schedule {
// is picking a task with the maximum earnings ratio.
class EfficiencyPriority : public TaskScheduler {
public:
EfficiencyPriority(const std::vector<TuneTask>& tasks, const Config& config) : TaskScheduler(tasks, config) {}
EfficiencyPriority(const std::vector<TuneTask>& tasks, const Config& config)
: TaskScheduler(tasks, config) {}
const char* Name() const override { return "efficiency_priority"; };
......
......@@ -25,7 +25,8 @@ namespace auto_schedule {
// is picking a task to tune once a time iteratively.
class RoundRobin : public TaskScheduler {
public:
RoundRobin(const std::vector<TuneTask>& tasks, const Config& config) : TaskScheduler(tasks, config) {}
RoundRobin(const std::vector<TuneTask>& tasks, const Config& config)
: TaskScheduler(tasks, config) {}
const char* Name() const override { return "round_robin"; };
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册