diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6808d0a214aea3ff3c803bc15af42b106552f258..099e13c67c2bb8987cf8b30a97c704e2d92084a0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -28,6 +28,9 @@ set(AKG_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
 include(cmake/RT.cmake)
 include(cmake/utils.cmake)
 include(cmake/external_libs/isl.cmake)
+
+set(ISL_DIR "${CMAKE_BINARY_DIR}/isl")
+
 if(ENABLE_AKG)
   message("-- Build akg in Mindspore")
   execute_process(COMMAND bash ${AKG_SOURCE_DIR}/third_party/apply_patches.sh ${CMAKE_CURRENT_BINARY_DIR} "1")
@@ -43,8 +46,6 @@ else()
   set(UNITTEST_DIR "${AKG_SOURCE_DIR}/tests/unittest_cpp")
 endif()
-set(ISL_DIR "${CMAKE_BINARY_DIR}/isl")
 file(COPY ${AKG_SOURCE_DIR}/python/akg DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
@@ -175,6 +176,8 @@ file(
   ${TVM_DIR}/src/runtime/vm/profiler/*.cc
   ${TVM_DIR}/src/codegen/stackvm/*.cc
   ${AKG_SOURCE_DIR}/src/poly/*.cc
+  ${AKG_SOURCE_DIR}/src/poly/schedule_pass/*.cc
+  ${AKG_SOURCE_DIR}/src/poly/tiling/*.cc
   ${AKG_SOURCE_DIR}/src/api/*.cc
   ${AKG_SOURCE_DIR}/src/pass/*.cc
   ${AKG_SOURCE_DIR}/src/rpc/*.cc
diff --git a/src/pass/utils.cc b/src/pass/utils.cc
index 0b1ba852b0d68eea43b029fd66dd1d7f315b4ded..f0a95b1ed8d964b4bacde1cd990d2c8040940c31 100644
--- a/src/pass/utils.cc
+++ b/src/pass/utils.cc
@@ -29,7 +29,7 @@
 #include "ir_pass.h"
 #include "pass/utils.h"
 #include "pass/expr_alg_simplify.h"
-#include "poly/tiling_algorithm.h"
+#include "poly/tiling/tiling_algorithm.h"
 namespace akg {
 namespace ir {
diff --git a/src/poly/cce_isl_emitter.cc b/src/poly/cce_isl_emitter.cc
index b75b0d2944714b6f435061990da012604ce93ef5..7a674affbfe4d0b59dd5cd3dff22909a77a0d849 100644
--- a/src/poly/cce_isl_emitter.cc
+++ b/src/poly/cce_isl_emitter.cc
@@ -13,19 +13,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*/ -#include "poly/cce_isl_emitter.h" -#include -#include -#include -#include -#include -#include +#include "poly/cce_isl_emitter.h" #include "ir_pass.h" -#include "poly/isl.h" -#include "poly/poly_util.h" +#include "poly/dma_inject.h" #include "pass/utils.h" +#include "poly/spec_gemm_builder.h" namespace akg { namespace ir { @@ -403,7 +397,7 @@ std::vector GetLhsAllArgs(const CCEIslEmitter *emitter, const isl::ast_ auto node_id = node.get_annotation(); isl::ast_expr_op node_op; std::vector arg_ids; - if (!emitter->scop_.IsRead(stmt_id) && !emitter->scop_.IsWrite(stmt_id)) { + if (!emitter->info_.IsRead(stmt_id) && !emitter->info_.IsWrite(stmt_id)) { node_op = node.get_expr().as(); if (!node_op) return arg_ids; } else { @@ -412,9 +406,9 @@ std::vector GetLhsAllArgs(const CCEIslEmitter *emitter, const isl::ast_ auto hoisted = iterator_map.range_factor_range(); auto original = iterator_map.range_factor_domain().range_factor_range(); auto build = emitter->node_info_map_.at(node_id).build; - if (emitter->scop_.IsRead(stmt_id)) { + if (emitter->info_.IsRead(stmt_id)) { node_expr = build.access_from(isl::multi_pw_aff(hoisted)); - } else if (emitter->scop_.IsWrite(stmt_id)) { + } else if (emitter->info_.IsWrite(stmt_id)) { node_expr = build.access_from(isl::multi_pw_aff(original)); } node_op = node_expr.as(); @@ -505,7 +499,7 @@ bool CCEIslEmitter::InjectMulticore(const std::string &iter) { if (should_insert_multi_core) { ++multicore_info.multicore_depth; --multicore_info.coincidence[coincident_member]; - } + } } } else { LOG(WARNING) << "multicore: unrecognized loop var " << iter; @@ -545,7 +539,7 @@ Stmt CCEIslEmitter::EmitFor(const isl::ast_node_for &node) { Stmt stmt; if (body_stmt.defined()) { stmt = For::make(iter_expr, init_expr, cond_expr, ForType::Serial, DeviceAPI::None, body_stmt); - if (scop_.optimize_for_davinci_) { + if (info_.user_config_.GetOptimizeForDavinci()) { const int DAVINCIC0SIZE = 16; // need to find the last axis if (Equal(cond_expr, Expr(DAVINCIC0SIZE)) && ForShouldPassDown(this, node, isl_iter_id)) { @@ -567,26 +561,10 @@ Stmt CCEIslEmitter::EmitFor(const isl::ast_node_for &node) { return stmt; } -Stmt CCEIslEmitter::EmitIf(const isl::ast_node_if &node) { - // get cond - Expr cond_expr = Interpret(node.get_cond()); - - cur_if_list_.push_back(cond_expr.get()); - - // get then_case - Stmt then_case = EmitAst(node.get_then_node()); - - // get else_case - Stmt else_case; - if (node.has_else_node()) { - else_case = EmitAst(node.get_else_node()); - } - - cur_if_list_.pop_back(); - return IfThenElse::make(cond_expr, then_case, else_case); -} - Expr CCEIslEmitter::EmitLoad(const isl::ast_expr &expr, const Type type) { + if (PRINT_CCE_ISL_EMMITER) { + LOG(INFO) << ">>>>>>>>>>>>INPUT AST_NODE[LOAD]<<<<<<<<<<<<<<\n" << expr; + } if (auto op = expr.as()) { if (auto access = op.as()) { // make buffer, index @@ -596,24 +574,32 @@ Expr CCEIslEmitter::EmitLoad(const isl::ast_expr &expr, const Type type) { for (unsigned int i = 1; i < op.get_n_arg(); ++i) { local_args.push_back(Interpret(op.get_arg(i))); } - if (scop_.CountBufferDefInfo(var)) { + if (info_.analysis_result_.CountBufferDefInfo(var)) { realize_use_.insert(var); if (!if_map_.count(var) || !AOutThanB(if_map_.at(var), cur_if_list_)) { realize_use_with_may_def_.insert(var); } } - Tensor t = scop_.FindTensor(var); - if (scop_.IsIm2col()) { + Tensor t = info_.FindTensor(var); + if (info_.cube_info_.IsIm2col()) { // compute_local_UB find compute std::string name = t->op->name; - for (const auto &updateTensor : 
scop_.data_.update_tensors) { + for (const auto &updateTensor : info_.analysis_result_.GetUpdateTensor()) { if (updateTensor->op->name == name) { - return Call::make(type, updateTensor->op->name, local_args, Call::CallType::Halide, updateTensor->op, - updateTensor->value_index); + auto call = Call::make(type, updateTensor->op->name, local_args, Call::CallType::Halide, updateTensor->op, + updateTensor->value_index); + if (PRINT_CCE_ISL_EMMITER) { + LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << call; + } + return call; } } } - return Call::make(type, t->op->name, local_args, Call::CallType::Halide, t->op, t->value_index); + auto call = Call::make(type, t->op->name, local_args, Call::CallType::Halide, t->op, t->value_index); + if (PRINT_CCE_ISL_EMMITER) { + LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << call; + } + return call; } } return Expr(); @@ -633,27 +619,72 @@ static isl_stat ExtractSinglePiece(__isl_take isl_set *set, __isl_take isl_aff * return isl_stat_error; } -Stmt CCEIslEmitter::EmitL1Read(const isl::ast_node_user &node) { +static isl::pw_multi_aff ComputeNewBufferFootprint(const std::shared_ptr &fp_cluster, + const isl::pw_multi_aff &buffer_footprint) { + if (!fp_cluster->UnWriteable()) return buffer_footprint; + if (!fp_cluster->foot_print_.is_valid) return buffer_footprint; + unsigned num_dims = fp_cluster->foot_print_.GetBoxDim(); + + isl::pw_multi_aff new_buffer_footprint = buffer_footprint; + for (unsigned dim = 0; dim < num_dims; ++dim) { + isl::aff lower_bound = fp_cluster->foot_print_.GetBoxLowerBound(dim); + isl::pw_aff dim_buf_fp = buffer_footprint.get_pw_aff(dim); + if (dim_buf_fp.n_piece() != 1) return buffer_footprint; + // there is only one piece, but we have to use the foreach API + dim_buf_fp.foreach_piece([&lower_bound, &new_buffer_footprint, &dim](const isl::set &set, + const isl::aff &aff) -> void { + if (IsAffVarPlusOffset(lower_bound) && IsAffNonZeroConst(aff)) { + isl::pw_aff zero = isl::pw_aff(isl::manage(isl_aff_set_constant_si(aff.copy(), 0))); + new_buffer_footprint = isl::manage(isl_pw_multi_aff_set_pw_aff(new_buffer_footprint.copy(), dim, zero.copy())); + } + }); + } + return new_buffer_footprint; +} + +/* + * Remove the constant offset from provide args, e.g. input_1_local_UB(32, 7, cc2, cc3) = input_1(...) + * Check the footprint cluster of the hoisted var to confirm this input tensor has multiple accesses + * from shifted tiles. This should be improved by computing the new footprint with footprint_per_access(), + * but from isl AST we do not know the footprint ID that corresponds to the GM -> UB copy. 
+ */ +isl::pw_multi_aff RemoveConstOffsetFromBufferFootprint( + const isl::pw_multi_aff &buffer_footprint, + const std::vector> &active_buffer_footprints) { + const isl::id buffer_id = buffer_footprint.get_tuple_id(isl_dim_out); + for (const auto &act_buf : active_buffer_footprints) { + if (act_buf.second.cluster_id == buffer_id) { + const auto &footprint_cluster = act_buf.second.cluster; + return ComputeNewBufferFootprint(footprint_cluster, buffer_footprint); + } + } + return buffer_footprint; +} + +Stmt CCEIslEmitter::EmitRead(const isl::ast_node_user &node) { isl::id node_id = node.get_annotation(); isl::pw_multi_aff iterator_map = node_info_map_.at(node_id).iterator_map; isl::pw_multi_aff hoisted = iterator_map.range_factor_range(); isl::pw_multi_aff original = iterator_map.range_factor_domain().range_factor_range(); isl::id original_tensor = original.get_tuple_id(isl_dim_out); - bool isInputTensor = scop_.FindTensorInOrig(original_tensor).defined(); - if (isInputTensor) hoisted = scop_.RemoveConstOffsetFromBufferFootprint(hoisted); + bool isInputTensor = info_.FindTensorInOrig(original_tensor).defined(); + if (isInputTensor) + hoisted = RemoveConstOffsetFromBufferFootprint(hoisted, info_.analysis_result_.ActiveBufferFootprints()); auto build = node_info_map_.at(node_id).build; auto lhs = build.access_from(isl::multi_pw_aff(hoisted)); auto rhs = build.access_from(isl::multi_pw_aff(original)); - size_t pos = scop_.GetBName().find("_local"); - std::string b_name = pos == std::string::npos ? scop_.GetBName() : scop_.GetBName().substr(0, pos); + size_t pos = info_.cube_info_.GetBName().find("_local"); + std::string b_name = + pos == std::string::npos ? info_.cube_info_.GetBName() : info_.cube_info_.GetBName().substr(0, pos); auto b_l1_name = b_name + "_local_L1"; - if (scop_.matB_dim_h_ > 0 && scop_.matB_dim_w_ > 0 && original.get_tuple_id(isl_dim_out).get_name() == b_l1_name) { - auto h_size = scop_.matB_dim_h_; - auto w_size = scop_.matB_dim_w_; + if (info_.user_config_.GetMatBDimH() > 0 && info_.user_config_.GetMatBDimW() > 0 && + original.get_tuple_id(isl_dim_out).get_name() == b_l1_name) { + auto h_size = info_.user_config_.GetMatBDimH(); + auto w_size = info_.user_config_.GetMatBDimW(); auto mpa = isl::multi_pw_aff(original); auto size = mpa.size(); @@ -696,7 +727,7 @@ Stmt CCEIslEmitter::EmitL1Read(const isl::ast_node_user &node) { rhs = build.access_from(isl::multi_pw_aff(ma)); } - Type type = scop_.GetDtypeOf(rhs); + Type type = info_.GetDtypeOf(rhs); if (auto op = lhs.as()) { if (auto access = op.as()) { Expr value = EmitLoad(rhs, type); @@ -707,9 +738,9 @@ Stmt CCEIslEmitter::EmitL1Read(const isl::ast_node_user &node) { local_args.push_back(Interpret(op.get_arg(i))); } - Tensor t = scop_.FindTensor(var); + Tensor t = info_.FindTensor(var); CHECK(t.defined()); - if (scop_.CountBufferDefInfo(var)) { + if (info_.analysis_result_.CountBufferDefInfo(var)) { realize_may_def_.insert(var); if_map_.emplace(var, cur_if_list_); if (cur_if_list_.empty()) { @@ -717,8 +748,8 @@ Stmt CCEIslEmitter::EmitL1Read(const isl::ast_node_user &node) { } } hoisted_read_.insert(var); - if (scop_.IsIm2col() && !scop_.data_.update_tensors.empty()) { - return Provide::make(scop_.data_.update_tensors[0]->op, 0, value, local_args); + if (info_.cube_info_.IsIm2col() && !info_.analysis_result_.GetUpdateTensor().empty()) { + return Provide::make(info_.analysis_result_.GetUpdateTensor()[0]->op, 0, value, local_args); } return Provide::make(t->op, 0, value, local_args); } @@ -726,7 +757,7 @@ Stmt 
CCEIslEmitter::EmitL1Read(const isl::ast_node_user &node) { return Stmt(); } -Stmt CCEIslEmitter::EmitL1Write(const isl::ast_node_user &node, Scop::AtomicType atomic) { +Stmt CCEIslEmitter::EmitWrite(const isl::ast_node_user &node, AtomicType atomic) { auto node_id = node.get_annotation(); CHECK_GT(node_info_map_.count(node_id), 0); auto iterator_map = node_info_map_.at(node_id).iterator_map; @@ -735,18 +766,18 @@ Stmt CCEIslEmitter::EmitL1Write(const isl::ast_node_user &node, Scop::AtomicType // refine atomic from reduce op bool doatomic = false; - if (atomic == Scop::AtomicType::Add) { + if (atomic == AtomicType::Add) { auto srcid = original.get_tuple_id(isl_dim_out); - for (const auto &i : scop_.data_.statements) { + for (const auto &i : info_.analysis_result_.GetStatementMap()) { std::set rmv; const auto provide = static_cast(i.second); - if (provide == nullptr || scop_.data_.reduces.count(provide) != 1) continue; + if (provide == nullptr || info_.analysis_result_.GetReduceMap().count(provide) != 1) continue; if (provide->func->func_name() != srcid.get_name()) continue; doatomic = true; if (!stmt_var_map_.count(i.first)) continue; VarMap vmap = stmt_var_map_.at(i.first); for (const auto &j : vmap) { - for (auto k : scop_.data_.reduces.at(provide)) { + for (auto k : info_.analysis_result_.GetReduceMap().at(provide)) { if (k->var.get()->name_hint != j.first.get_name()) continue; std::vector iters = ExtractIterfromExpr().Run(j.second); for (auto v : iters) @@ -761,7 +792,7 @@ Stmt CCEIslEmitter::EmitL1Write(const isl::ast_node_user &node, Scop::AtomicType auto build = node_info_map_.at(node_id).build; auto rhs = build.access_from(isl::multi_pw_aff(hoisted)); auto lhs = build.access_from(isl::multi_pw_aff(original)); - Type type = scop_.GetDtypeOf(lhs); + Type type = info_.GetDtypeOf(lhs); if (auto op = lhs.as()) { if (auto access = op.as()) { @@ -773,9 +804,9 @@ Stmt CCEIslEmitter::EmitL1Write(const isl::ast_node_user &node, Scop::AtomicType local_args.push_back(Interpret(op.get_arg(static_cast(i)))); } - Tensor t = scop_.FindTensor(var); + Tensor t = info_.FindTensor(var); CHECK(t.defined()); - if (scop_.CountBufferDefInfo(var)) { + if (info_.analysis_result_.CountBufferDefInfo(var)) { realize_may_def_.insert(var); if_map_.emplace(var, cur_if_list_); if (cur_if_list_.empty()) { @@ -791,7 +822,7 @@ Stmt CCEIslEmitter::EmitL1Write(const isl::ast_node_user &node, Scop::AtomicType } // remove original copy out promotion statement because it is sinked into if stmt of computation - if (scop_.conditional_write_buffer_footprints_.count(t->op->name)) return Evaluate::make(0); + if (info_.analysis_result_.GetConditionalWriteBufferFootprints().count(t->op->name)) return Evaluate::make(0); return Provide::make(t->op, 0, value, local_args); } @@ -804,15 +835,46 @@ Stmt CCEIslEmitter::EmitUserStmt(const isl::ast_node_user &node) { LOG(INFO) << "don't emit conv origin user stmt."; return Evaluate::make(Expr(0)); } else { - auto user_stmt = IslEmitter::EmitUserStmt(node); + CHECK(node.get_expr().isa()); + isl::ast_expr_op usr_expr = node.get_expr().as(); + stmt_id_ = usr_expr.get_arg(0).as().get_id(); + node_id_ = node.get_annotation(); + const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(stmt_id_); + CHECK(stmt_node); + // compute VarMap to replace old iterators + auto build = node_info_map_.at(node_id_).build; + auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id_).tuple; + auto iterator_map = node_info_map_.at(node_id_).iterator_map; + + var_map_.clear(); + for 
(unsigned int i = 0; i < tuple.size(); ++i) { + isl::id isl_old_iter = tuple.get_id(i); + auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i)); + Expr halide_new_iter = Interpret(isl_expr); + var_map_.emplace(isl_old_iter, halide_new_iter); + std::string replace_id = isl_old_iter.get_name() + "_"; + std::vector vec = ExtractIterfromExpr().Run(halide_new_iter); + for (auto item : vec) { + std::string new_name = item->name_hint; + auto iter_prefix = info_.user_config_.GetIterPrefix(info_.cube_info_.IsSpecGemm()); + size_t pos = new_name.find(iter_prefix); + if (pos != std::string::npos) { + new_name = new_name.replace(pos, iter_prefix.size(), replace_id); + iters_old_name_.emplace(item, item->name_hint); + iters_new_name_.emplace(item, new_name); + } + } + } + + VarMap vmap = var_map_; + stmt_var_map_.emplace(stmt_id_, vmap); + auto user_stmt = EmitUserStmtContent(stmt_node); // fix conv prefusion dma if condition bool add_attr = false; - const Node *stmt_node = scop_.data_.statements.at(stmt_id_); - CHECK(stmt_node); std::string type_key = std::string(stmt_node->GetTypeKey()); - if (!scop_.is_spec_gemm_ && (type_key == "IfThenElse")) { - isl::union_set transfer_stmt = scop_.data_.transfer_stmt; + if (!info_.cube_info_.IsSpecGemm() && (type_key == "IfThenElse")) { + isl::union_set transfer_stmt = info_.analysis_result_.GetTransferStmt(); if (!transfer_stmt.is_empty()) { transfer_stmt.foreach_set([&add_attr, this](const isl::set &s) -> void { if (s.get_tuple_name() == stmt_id_.get_name()) { @@ -828,6 +890,24 @@ Stmt CCEIslEmitter::EmitUserStmt(const isl::ast_node_user &node) { } } +AtomicType GetAtomicWrite(const isl::id &id, const StatementMap &statements) { + for (const auto &i : statements) { + const Node *stmt_node = i.second; + if (stmt_node->IsInstance()) { + auto provide = static_cast(stmt_node); + if (const auto cop = provide->func.as()) { + if (cop->attrs.count(ATTR_ATOMIC_ADD) != 0) { + if (auto str_op = cop->attrs.at(ATTR_ATOMIC_ADD).as()) { + auto str = str_op->value; + if (str == id.get_name()) return AtomicType::Add; + } + } + } + } + } + return AtomicType::Equ; +} + Stmt CCEIslEmitter::EmitStmt(const isl::ast_node_user &node) { CHECK(node.get_expr().isa()); isl::ast_expr_op usr_expr = node.get_expr().as(); @@ -835,16 +915,28 @@ Stmt CCEIslEmitter::EmitStmt(const isl::ast_node_user &node) { auto stmt_id = usr_expr.get_arg(0).as().get_id(); auto node_id = node.get_annotation(); - if (scop_.IsRead(stmt_id)) { - return EmitL1Read(node); - } else if (scop_.IsWrite(stmt_id)) { - if (scop_.IsGMWrite(stmt_id)) { + if (info_.IsRead(stmt_id)) { + auto s = EmitRead(node); + if (PRINT_CCE_ISL_EMMITER) { + LOG(INFO) << ">>>>>>>>>>>>INPUT AST_NODE[READ]<<<<<<<<<<<<<<\n" << node; + LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << s; + } + return s; + } else if (info_.IsWrite(stmt_id)) { + auto s = Stmt(); + if (info_.IsGMWrite(stmt_id)) { auto iterator_map = node_info_map_.at(node_id).iterator_map; auto original = iterator_map.range_factor_domain().range_factor_range(); auto srcid = original.get_tuple_id(isl_dim_out); - return EmitL1Write(node, scop_.GetAtomicWrite(srcid)); + s = EmitWrite(node, GetAtomicWrite(srcid, info_.analysis_result_.GetStatementMap())); + } else { + s = EmitWrite(node, AtomicType::Equ); + } + if (PRINT_CCE_ISL_EMMITER) { + LOG(INFO) << ">>>>>>>>>>>>INPUT AST_NODE[WRITE]<<<<<<<<<<<<<<\n" << node; + LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << s; } - return EmitL1Write(node, Scop::AtomicType::Equ); + return s; } else { SetCube(stmt_id); return 
EmitUserStmt(node); @@ -852,7 +944,7 @@ Stmt CCEIslEmitter::EmitStmt(const isl::ast_node_user &node) { } void CCEIslEmitter::SetCube(const isl::id &stmt_id) { - auto cur_op = scop_.data_.stmt_op_Info.at(stmt_id); + auto cur_op = info_.analysis_result_.GetStmtOpInfoMap().at(stmt_id); opinfo_.isCube = cur_op.isCube || opinfo_.isCube; opinfo_.ops.insert(opinfo_.ops.end(), cur_op.ops.begin(), cur_op.ops.end()); is_cube_ = true; @@ -1002,9 +1094,10 @@ Stmt CCEIslEmitter::EmitGemmRangeInfoBackPropFilter(const Stmt &stmt) { return AttrStmt::make(range_map, "pragma_gemm_l0", Expr(l0_range_idx), stmt); } -void CCEIslEmitter::EmitGemmRangeInfoNewAxis(std::vector &range, std::vector &prefix, - std::unordered_map &outerAxis, Range &axisMRange, - Map &range_map, Map &axis_map) { +void CCEIslEmitter::CollectGemmRangeInfoNewAxis(std::vector &range, std::vector &prefix, + std::unordered_map &outerAxis, Range &axisMRange, + Map &range_map, + Map &axis_map) { for (unsigned int i = 0; i < range.size(); i++) { std::stringstream ss; ss << "ee" << i; @@ -1109,7 +1202,7 @@ Stmt CCEIslEmitter::EmitGemmRangeInfo(Stmt stmt) { ***********************/ // spec gemm set dim outer outer range std::vector range; - if (scop_.tile_size_is_var_) { + if (info_.user_config_.GetTileSizeIsVar()) { // must equal to scop.cc const int t0_mo = 11; const int t0_ko = 13; @@ -1120,7 +1213,7 @@ Stmt CCEIslEmitter::EmitGemmRangeInfo(Stmt stmt) { range.emplace_back(Expr(0), Expr(1)); range.emplace_back(Expr(0), floordiv(Var("KO") + t0_ko - 1, t0_ko)); } else { - range = scop_.GetRange(range_idx_); + range = info_.cube_info_.GetRange(range_idx_); } Map range_map; Map axis_map; @@ -1131,7 +1224,7 @@ Stmt CCEIslEmitter::EmitGemmRangeInfo(Stmt stmt) { CHECK(prefix.size() == range.size()); Range axis_m_range; - EmitGemmRangeInfoNewAxis(range, prefix, outer_axis, axis_m_range, range_map, axis_map); + CollectGemmRangeInfoNewAxis(range, prefix, outer_axis, axis_m_range, range_map, axis_map); std::vector all_axis; all_axis.emplace_back("mo_"); @@ -1153,10 +1246,10 @@ Stmt CCEIslEmitter::EmitGemmRangeInfo(Stmt stmt) { * **********************************/ PartitionSingle *single = PartitionSingle::getInstance(); if (single != nullptr) { - if (!scop_.is_dynamic_) { - EmitGemmRangeInfoDynamic(axis_m_range, range_map); + if (!info_.user_config_.GetIsDynamic()) { + CollectGemmMWSize(axis_m_range, range_map); } else { - EmitGemmRangeInfoStatic(range_map); + CollectGemmMWSizeDynamic(range_map); } } stmt = AttrStmt::make(axis_map, "pragma_spec_gemm_attr", Expr(0), stmt); @@ -1165,7 +1258,7 @@ Stmt CCEIslEmitter::EmitGemmRangeInfo(Stmt stmt) { return stmt; } -void CCEIslEmitter::EmitGemmRangeInfoDynamic(Range &axis_m_range, Map &range_map) { +void CCEIslEmitter::CollectGemmMWSize(Range &axis_m_range, Map &range_map) { std::map fractal_int_info = PartitionSingle::getFractalInfo(); CHECK(fractal_int_info.find(ATTR_CONV_GMM_M) != fractal_int_info.end()); CHECK(fractal_int_info.find(ATTR_CONV_TILE_M) != fractal_int_info.end()); @@ -1203,7 +1296,7 @@ void CCEIslEmitter::EmitGemmRangeInfoDynamic(Range &axis_m_range, Map &range_map) { +void CCEIslEmitter::CollectGemmMWSizeDynamic(Map &range_map) { std::map fractal_int_info = PartitionSingle::getFractalInfo(); CHECK(fractal_int_info.find(ATTR_CONV_GMM_M) != fractal_int_info.end()); CHECK(fractal_int_info.find(ATTR_CONV_TILE_M) != fractal_int_info.end()); @@ -1232,8 +1325,8 @@ void CCEIslEmitter::EmitGemmRangeInfoStatic(Map &range_map) } std::string CCEIslEmitter::FindRealizeScopeToString(const isl::id &var) { - 
if (scop_.CountBufferDefInfo(var)) { - auto tensor_info = scop_.GetBufferDefInfo(var); + if (info_.analysis_result_.CountBufferDefInfo(var)) { + auto tensor_info = info_.analysis_result_.GetBufferDefInfo(var); MemType mem_type = tensor_info.DstMemType(); switch (mem_type) { @@ -1277,11 +1370,11 @@ Stmt CCEIslEmitter::InsertRealize(Stmt stmt, const isl::id &var, bool is_L0) { // A tensor may be defined multiple times in BufferDefInfo due to nested realize. // Because we cannot determine which one we actually want, we have to be conservative here // and allocate space for the largest shape to avoid overflow. - Tensor t = scop_.FindTensorWithLargestShape(var); + Tensor t = info_.FindTensorWithLargestShape(var); Region bounds; - if (scop_.IsCUB(var.get_name())) { - auto ct = scop_.FindTensor(var.get_name() + "_local_L0C"); + if (info_.cube_info_.IsCUB(var.get_name())) { + auto ct = info_.FindTensor(var.get_name() + "_local_L0C"); for (auto j : ct->shape) { bounds.push_back(Range::make_by_min_extent(Expr(0), j)); } @@ -1294,18 +1387,18 @@ Stmt CCEIslEmitter::InsertRealize(Stmt stmt, const isl::id &var, bool is_L0) { } // If isolate, make a new buffer - auto buf = scop_.binds_.at(t); + auto buf = info_.user_config_.GetBind().at(t); auto tt = placeholder(t->shape, t->dtype, t->op->name); stmt = TensorSubstitute(stmt, t->op, tt->op, tt->value_index); t = tt; - if (scop_.CountBufferDefInfo(var)) { - auto decl = scop_.GetBufferDefInfo(var); + if (info_.analysis_result_.CountBufferDefInfo(var)) { + auto decl = info_.analysis_result_.GetBufferDefInfo(var); decl.tensor = t; } - scop_.binds_.Set(t, buf); + info_.user_config_.SetBind(t, buf); - if (!scop_.IsIm2col()) { + if (!info_.cube_info_.IsIm2col()) { stmt = TensorSubstitute2(stmt, t->op->func_name(), t->op, t->value_index); } @@ -1321,8 +1414,8 @@ Stmt CCEIslEmitter::InsertRealize(Stmt stmt, const isl::id &var, bool is_L0) { } } - if (scop_.IsIm2col()) { - for (const auto &curTensor : scop_.data_.update_tensors) { + if (info_.cube_info_.IsIm2col()) { + for (const auto &curTensor : info_.analysis_result_.GetUpdateTensor()) { // find the updateTensor with same name and make Realize and AttrStmt if (curTensor->op->name == t->op->name) { stmt = Realize::make(curTensor->op, t->value_index, t->dtype, bounds, const_true(1), stmt); @@ -1339,25 +1432,25 @@ Stmt CCEIslEmitter::InsertRealize(Stmt stmt, const isl::id &var, bool is_L0) { return stmt; } -Stmt HoistL0write(Scop &scop, const Stmt &body, std::vector &l0write) { +Stmt HoistL0write(ScopInfo &info, const Stmt &body, std::vector &l0write) { Stmt stmt = body; if (!l0write.empty()) { - if (scop.IsGemm()) { - auto f = HoistL0Write(scop.binds_orig_, l0write.back()); + if (info.cube_info_.IsGemm()) { + auto f = HoistL0Write(info.user_config_.GetOriginBind(), l0write.back()); static_cast(f.Mutate(body)); f.mutate_ = true; stmt = f.Mutate(body); if (!f.found_) stmt = Block::make(body, l0write.back()); - } else if (scop.is_spec_gemm_) { + } else if (info.cube_info_.IsSpecGemm()) { stmt = Block::make(body, l0write.back()); } } return stmt; } -void CCEIslEmitter::ProcBypassL1(const Scop &scop) { +void CCEIslEmitter::ProcBypassL1(const ScopInfo &info) { if (0 == bypassL1_) { - bypassL1_ = scop_.bypassL1_; + bypassL1_ = info.user_config_.GetByPassL1(); } } @@ -1365,8 +1458,8 @@ Stmt CCEIslEmitter::EmitSpecGemL1write(const isl::ast_node_mark &node, const Stm is_old_gemm_l1write_ = true; static_cast(EmitAst(node.get_node())); is_old_gemm_l1write_ = false; - if (!scop_.is_spec_gemm_ && !scop_.old_l1_write_.empty()) { 
- return Block::make(stmt, scop_.old_l1_write_.back()); + if (!info_.cube_info_.IsSpecGemm() && !info_.cube_info_.GetOldL1Write().empty()) { + return Block::make(stmt, info_.cube_info_.GetOldL1Write().back()); } return stmt; } @@ -1374,13 +1467,13 @@ Stmt CCEIslEmitter::EmitSpecGemL1write(const isl::ast_node_mark &node, const Stm void CCEIslEmitter::EmitAttrStmtAfterRealize(bool is_L1, bool is_L0, std::vector &stmts) { // Emit attrs of provide if (is_L1) { - for (const auto &i : scop_.data_.stmt_op_Info) { + for (const auto &i : info_.analysis_result_.GetStmtOpInfoMap()) { if (!i.second.isCube) continue; - const Node *stmt_node = scop_.data_.statements.at(i.first); + const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(i.first); if (!stmt_node->IsInstance()) continue; const auto provide = static_cast(stmt_node); - if (!scop_.attr_info_.empty()) { - stmts[0] = AttrStmt::make(scop_.attr_info_, "pragma_attrs", Expr(1), stmts[0]); + if (!info_.cube_info_.GetConvAttrInfo().empty()) { + stmts[0] = AttrStmt::make(info_.cube_info_.GetConvAttrInfo(), "pragma_attrs", Expr(1), stmts[0]); } else if (const auto cop = provide->func.as()) { stmts[0] = AttrStmt::make(cop->attrs, "pragma_attrs", Expr(1), stmts[0]); } @@ -1389,8 +1482,8 @@ void CCEIslEmitter::EmitAttrStmtAfterRealize(bool is_L1, bool is_L0, std::vector } } - if (scop_.is_spec_gemm_ && is_L0) { - if (scop_.conv_back_prop_filter_) { + if (info_.cube_info_.IsSpecGemm() && is_L0) { + if (info_.user_config_.GetConvBackPropFilter()) { stmts[0] = EmitGemmRangeInfoBackPropFilter(stmts[0]); } else { stmts[0] = EmitGemmRangeInfo(stmts[0]); @@ -1399,21 +1492,24 @@ void CCEIslEmitter::EmitAttrStmtAfterRealize(bool is_L1, bool is_L0, std::vector } void CCEIslEmitter::GemmTranspose(std::vector &stmts) { - if (scop_.IsGemmDataTranspose()) { - bool transBlock = !scop_.IsGemmDataTransposeInnerBlock(); - bool transIn = !scop_.IsGemmDataTransposeBlock(); + if (info_.cube_info_.IsGemmDataTranspose()) { + bool transBlock = !info_.cube_info_.IsGemmDataTransposeInnerBlock(); + bool transIn = !info_.cube_info_.IsGemmDataTransposeBlock(); stmts[0] = TransposeLoopVarOrderInMad().Run(stmts[0], "_L1_local_L0A", transBlock, transIn); } - if (scop_.IsGemmWeightTranspose()) { - bool transBlock = !scop_.IsGemmWeightTransposeInnerBlock(); - bool transIn = !scop_.IsGemmWeightTransposeBlock(); + if (info_.cube_info_.IsGemmWeightTranspose()) { + bool transBlock = !info_.cube_info_.IsGemmWeightTransposeInnerBlock(); + bool transIn = !info_.cube_info_.IsGemmWeightTransposeBlock(); stmts[0] = TransposeLoopVarOrderInMad().Run(stmts[0], "_L1_local_L0B", transBlock, transIn); } } -void CCEIslEmitter::EmitAttrStmtL0(Tensor &t, bool &is_im2col, bool &is_filter_l0, bool &is_gemm_data_trans, - bool &is_gemm_weight_trans) { - if (scop_.is_spec_gemm_) { +void CCEIslEmitter::EmitReadAttrAtL0(std::vector &stmts, int i, Tensor &t) { + bool is_im2col = false; + bool is_filter_l0 = false; + bool is_gemm_data_trans = false; + bool is_gemm_weight_trans = false; + if (info_.cube_info_.IsSpecGemm()) { // this case is conv gemm if (t->op->name.find("_fractal_L1_local_L0A") != std::string::npos || t->op->name.find("_fractal_L1_local_L0B") != std::string::npos) { @@ -1424,10 +1520,10 @@ void CCEIslEmitter::EmitAttrStmtL0(Tensor &t, bool &is_im2col, bool &is_filter_l t->op->name.find("_local_L1_local_L0A") != std::string::npos) { is_filter_l0 = true; } - } else if (!scop_.is_spec_gemm_) { + } else { // this case is ordinary gemm - std::string data_trans = 
scop_.ExtractStringFromAttrsAndInfo(ATTR_GEMM_DATA_TRANSPOSE); - std::string weight_trans = scop_.ExtractStringFromAttrsAndInfo(ATTR_GEMM_WEIGHT_TRANSPOSE); + std::string data_trans = info_.cube_info_.ExtractStringFromAttrsAndInfo(ATTR_GEMM_DATA_TRANSPOSE); + std::string weight_trans = info_.cube_info_.ExtractStringFromAttrsAndInfo(ATTR_GEMM_WEIGHT_TRANSPOSE); size_t pos1 = t->op->name.find("_L1_local_L0A"); size_t pos2 = t->op->name.find("_L1_local_L0B"); if (data_trans == "Y" && pos1 != std::string::npos) { @@ -1445,11 +1541,33 @@ void CCEIslEmitter::EmitAttrStmtL0(Tensor &t, bool &is_im2col, bool &is_filter_l if (pos2 != std::string::npos) is_filter_l0 = true; } } + + if (is_im2col) { + stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_im2col", Expr(1), stmts[i]); + } else if (is_gemm_data_trans) { + stmts[i] = + AttrStmt::make(make_zero(Int(32)), "pragma_load2d_transpose_data", Expr(gemm_transpose_index_), stmts[i]); + gemm_transpose_index_++; + gemm_transpose_index_ = gemm_transpose_index_ % 2; + } else if (is_gemm_weight_trans) { + stmts[i] = + AttrStmt::make(make_zero(Int(32)), "pragma_load2d_transpose_weight", Expr(gemm_transpose_index_), stmts[i]); + gemm_transpose_index_++; + gemm_transpose_index_ = gemm_transpose_index_ % 2; + } + stmts[i] = ProducerConsumer::make(t->op, true, stmts[i]); + if (bypassL1_ > 0) { + if (is_filter_l0) { + stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_bypass_filter_l0", Expr(0), stmts[i]); + } + } } -void CCEIslEmitter::EmitAttrStmtL1(Tensor &t, bool &is_fractal, bool &is_filter_l1) { - std::string fractal_str = scop_.ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME) + "_fractal_L1"; - std::string filter_str = scop_.ExtractStringFromAttrs(ATTR_CONV_FILTER_NAME) + "_local_L1"; +void CCEIslEmitter::EmitReadAttrAtL1(std::vector &stmts, int i, Tensor &t) { + bool is_fractal = false; + bool is_filter_l1 = false; + std::string fractal_str = info_.cube_info_.ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME) + "_fractal_L1"; + std::string filter_str = info_.cube_info_.ExtractStringFromAttrs(ATTR_CONV_FILTER_NAME) + "_local_L1"; if (fractal_str == t->op->name) { is_fractal = true; @@ -1459,31 +1577,55 @@ void CCEIslEmitter::EmitAttrStmtL1(Tensor &t, bool &is_fractal, bool &is_filter_ is_filter_l1 = true; } - std::string data_str = scop_.ExtractStringFromAttrs(ATTR_CONV_GMM_FEATURE) + "_local_L1"; - std::string weight_str = scop_.ExtractStringFromAttrs(ATTR_CONV_GMM_WEIGHT) + "_local_L1"; + std::string data_str = info_.cube_info_.ExtractStringFromAttrs(ATTR_CONV_GMM_FEATURE) + "_local_L1"; + std::string weight_str = info_.cube_info_.ExtractStringFromAttrs(ATTR_CONV_GMM_WEIGHT) + "_local_L1"; if ((bypassL1_ == 2 && data_str == t->op->name) || (bypassL1_ == 1 && weight_str == t->op->name)) { is_filter_l1 = true; } + + if (is_fractal) { + stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_fractal", Expr(1), stmts[i]); + } + stmts[i] = ProducerConsumer::make(t->op, true, stmts[i]); + if (bypassL1_ > 0) { + if (is_filter_l1) { + stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_bypass_filter_l1", Expr(0), stmts[i]); + } + } +} + +void CCEIslEmitter::EmitReadAttr(const std::vector &read, std::vector &stmts, int i, bool is_L1, + bool is_L0) { + for (const auto &id : read[i]) { + Tensor t = info_.FindTensor(id); + if (is_L1) { + EmitReadAttrAtL1(stmts, i, t); + } + + if (is_L0) { + EmitReadAttrAtL0(stmts, i, t); + } + } } -void CCEIslEmitter::EmitAttrStmtLiveness(const Liveness &liveness, std::vector &stmts, int i, bool is_L1) { - for (const auto &id : 
liveness.write_[i]) { - if (is_L1 && scop_.IsCUB(id.get_name())) continue; - if (is_old_gemm_l1write_ && scop_.IsC(id.get_name())) { +void CCEIslEmitter::EmitWriteAttr(const std::vector &write, std::vector &stmts, int i, bool is_L1) { + for (const auto &id : write[i]) { + if (is_L1 && info_.cube_info_.IsCUB(id.get_name())) continue; + if (is_old_gemm_l1write_ && info_.cube_info_.IsC(id.get_name())) { stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_cube_l1write", Expr(1), stmts[i]); - scop_.old_l1_write_.emplace_back(stmts[i]); + info_.cube_info_.OldL1WriteInsert(stmts[i]); } - if (scop_.is_spec_gemm_ && scop_.IsC(id.get_name())) { + if (info_.cube_info_.IsSpecGemm() && info_.cube_info_.IsC(id.get_name())) { stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_cube_l0write", Expr(1), stmts[i]); cube_l0write_.emplace_back(stmts[i]); stmts[i] = Evaluate::make(0); } - if (scop_.IsGemm() && !scop_.is_spec_gemm_ && scop_.IsCUB(id.get_name())) { + if (info_.cube_info_.IsGemm() && !info_.cube_info_.IsSpecGemm() && info_.cube_info_.IsCUB(id.get_name())) { stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_cube_l0write", Expr(1), stmts[i]); cube_l0write_.emplace_back(stmts[i]); stmts[i] = Evaluate::make(0); } - if (scop_.IsGemm() && !scop_.is_spec_gemm_ && scop_.IsC(id.get_name())) { + if (info_.cube_info_.IsGemm() && !info_.cube_info_.IsSpecGemm() && info_.cube_info_.IsC(id.get_name())) { stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_cube_l1write", Expr(1), stmts[i]); if (!cube_l0write_.empty()) { cube_l0write_.emplace_back(Block::make(cube_l0write_[0], stmts[i])); @@ -1496,56 +1638,14 @@ void CCEIslEmitter::EmitAttrStmtLiveness(const Liveness &liveness, std::vector &stmts) { for (unsigned int i = 0; i < block_node.get_children().size(); ++i) { - for (const auto &id : liveness.read_[i]) { - Tensor t = scop_.FindTensor(id); - bool is_im2col = false; - bool is_fractal = false; - - bool is_filter_l1 = false; - bool is_filter_l0 = false; - - bool is_gemm_data_trans = false; - bool is_gemm_weight_trans = false; - if (is_L0) { - EmitAttrStmtL0(t, is_im2col, is_filter_l0, is_gemm_data_trans, is_gemm_weight_trans); - } - - if (is_L1) { - EmitAttrStmtL1(t, is_fractal, is_filter_l1); - } - - if (is_im2col) { - stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_im2col", Expr(1), stmts[i]); - } else if (is_fractal) { - stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_fractal", Expr(1), stmts[i]); - } else if (is_gemm_data_trans) { - stmts[i] = - AttrStmt::make(make_zero(Int(32)), "pragma_load2d_transpose_data", Expr(gemm_transpose_index_), stmts[i]); - gemm_transpose_index_++; - gemm_transpose_index_ = gemm_transpose_index_ % 2; - } else if (is_gemm_weight_trans) { - stmts[i] = - AttrStmt::make(make_zero(Int(32)), "pragma_load2d_transpose_weight", Expr(gemm_transpose_index_), stmts[i]); - gemm_transpose_index_++; - gemm_transpose_index_ = gemm_transpose_index_ % 2; - } - stmts[i] = ProducerConsumer::make(t->op, true, stmts[i]); - if (bypassL1_ > 0) { - if (is_filter_l1) { - stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_bypass_filter_l1", Expr(0), stmts[i]); - } - if (is_filter_l0) { - stmts[i] = AttrStmt::make(make_zero(Int(32)), "pragma_bypass_filter_l0", Expr(0), stmts[i]); - } - } - } - EmitAttrStmtLiveness(liveness, stmts, i, is_L1); + EmitReadAttr(liveness.read_, stmts, i, is_L1, is_L0); + EmitWriteAttr(liveness.write_, stmts, i, is_L1); } } -void CCEIslEmitter::EmitRealizeLivenessInfo(std::vector &real, const Liveness &liveness_info, - std::unordered_map, 
isl::IslIdIslHash> &liveness, - std::function const &CheckGoOut) { +void CCEIslEmitter::CollectLiveness(const Liveness &liveness_info, bool is_L1, std::vector &real, + std::unordered_map, isl::IslIdIslHash> &liveness, + std::function const &CheckGoOut) { for (unsigned int i = 0; i < liveness_info.read_.size(); i++) { IslIdSet idset; real.push_back(idset); @@ -1580,8 +1680,8 @@ void CCEIslEmitter::EmitRealizeLivenessInfo(std::vector &real, const L // Now we just judge whole loop's liveness from existing WAR. // It is correct in gemm, conv etc. but may be wrong in other cases. - std::string tensor_name = scop_.GetOriginTensorId(j).get_name(); - if (scop_.MayWriteAfterRead(tensor_name) && CheckGoOut(j.get_name())) { + std::string tensor_name = info_.GetOriginTensorId(j).get_name(); + if (info_.MayWriteAfterRead(tensor_name) && CheckGoOut(j.get_name())) { realize_out_.insert(j); } if (!liveness.count(j)) { @@ -1591,7 +1691,7 @@ void CCEIslEmitter::EmitRealizeLivenessInfo(std::vector &real, const L liveness.at(j).insert(v); } for (const auto &j : liveness_info.write_[i]) { - if (!scop_.IsInBinds(j) && CheckGoOut(j.get_name())) realize_out_.insert(j); + if (!info_.IsInBinds(j) && CheckGoOut(j.get_name())) realize_out_.insert(j); } // isolated part, may reuse def in full tile. We realize them out @@ -1599,6 +1699,13 @@ void CCEIslEmitter::EmitRealizeLivenessInfo(std::vector &real, const L if (CheckGoOut(j.get_name())) realize_out_.insert(j); } } + /// amazing and fusing control: which should be realized out + if (is_L1) realize_out_.clear(); + + for (const auto &i : liveness) { + if (realize_out_.count(i.first)) continue; + real[(unsigned int)*i.second.begin()].insert(i.first); + } } // add realize @@ -1606,26 +1713,18 @@ void CCEIslEmitter::EmitRealizeLivenessInfo(std::vector &real, const L // we hack gemm C+=A*B and make C's liveness in the whole loop void CCEIslEmitter::EmitRealize(const isl::ast_node_block &block_node, const Liveness &liveness_info, bool is_L1, bool is_L0, std::vector &stmts) { - std::vector real; - std::unordered_map, isl::IslIdIslHash> liveness; - auto c_ub = scop_.is_spec_gemm_ ? scop_.GetCName() : scop_.GetCName() + "_local_UB"; + auto c_ub = info_.cube_info_.IsSpecGemm() ? 
info_.cube_info_.GetCName() : info_.cube_info_.GetCName() + "_local_UB"; auto c_l0c = c_ub + "_local_L0C"; auto CheckGoOut = [&c_ub, &c_l0c](const std::string &id) -> bool { return !(id == c_ub || id == c_l0c); }; - EmitRealizeLivenessInfo(real, liveness_info, liveness, CheckGoOut); - - /// amazing and fusing control: which should be realized out - if (is_L1) realize_out_.clear(); - - for (const auto &i : liveness) { - if (realize_out_.count(i.first)) continue; - real[(unsigned int)*i.second.begin()].insert(i.first); - } + std::vector real; + std::unordered_map, isl::IslIdIslHash> liveness; + CollectLiveness(liveness_info, is_L1, real, liveness, CheckGoOut); size_t last = block_node.get_children().size() - 1; for (const auto &var : real[last]) { /// so far our alloc_C is only designed for specgemm - if (scop_.is_spec_gemm_ || scop_.IsConv()) { + if (info_.cube_info_.IsSpecGemm() || info_.cube_info_.IsConv()) { if (!CheckGoOut(var.get_name())) continue; } @@ -1638,16 +1737,16 @@ void CCEIslEmitter::EmitRealize(const isl::ast_node_block &block_node, const Liv for (const auto &var : real[p]) { /// so far our alloc_C is only designed for specgemm - if (scop_.is_spec_gemm_ || scop_.IsConv()) { + if (info_.cube_info_.IsSpecGemm() || info_.cube_info_.IsConv()) { if (!CheckGoOut(var.get_name())) continue; } stmts[p] = InsertRealize(stmts[p], var, is_L0); if (!DELETE_FRACTAL) continue; - std::string feature_str = scop_.ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME) + "_local_L1"; + std::string feature_str = info_.cube_info_.ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME) + "_local_L1"; if (feature_str == var.get_name()) { - std::string fractal_str = scop_.ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME) + "_fractal_L1"; + std::string fractal_str = info_.cube_info_.ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME) + "_fractal_L1"; stmts[p] = InsertRealize(stmts[p], isl::id(var.ctx(), fractal_str), is_L0); } } @@ -1711,7 +1810,7 @@ Stmt CCEIslEmitter::EmitBlock(const isl::ast_node_block &block_node) { } void CCEIslEmitter::ConvBackPropFilterFixMadInit(const isl::ast_node_mark &node, Expr &mad_init_cond) { - if (scop_.IsConvBackpropFilter()) { + if (info_.cube_info_.IsConvBackpropFilter()) { /// find reduce k; /// correct axles' name FindStmt fs = FindStmt(); @@ -1723,19 +1822,19 @@ void CCEIslEmitter::ConvBackPropFilterFixMadInit(const isl::ast_node_mark &node, CHECK(usr_expr.get_arg(0).isa()); isl::id curstmtid = usr_expr.get_arg(0).as().get_id(); isl::id curnodeid = i.get_annotation(); - const Node *stmt_node = scop_.data_.statements.at(curstmtid); + const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(curstmtid); CHECK(stmt_node != nullptr); // stmt_node should not have if stmt if (stmt_node->IsInstance()) { auto build = node_info_map_.at(curnodeid).build; - auto tuple = scop_.data_.domains.at(curstmtid).tuple; + auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(curstmtid).tuple; auto iterator_map = node_info_map_.at(curnodeid).iterator_map; for (unsigned int n = 0; n < tuple.size(); n++) { isl::id isl_old_iter = tuple.get_id(n); bool is_red = false; - for (const auto &reds : scop_.data_.reduces) { + for (const auto &reds : info_.analysis_result_.GetReduceMap()) { for (auto j : reds.second) { // when support atomic add, "no" should not init in each core if (isl_old_iter.get_name() == j->var->name_hint && isl_old_iter.get_name() != "no") { @@ -1767,7 +1866,7 @@ void CCEIslEmitter::ConvBackPropFilterFixMadInit(const isl::ast_node_mark &node, Stmt 
CCEIslEmitter::EmitMarkFuseVector(const isl::ast_node_mark &node) { auto stmt = AttrStmt::make(make_zero(Int(32)), "pragma_fuse_vector", Expr(1), EmitAst(node.get_node())); - if (scop_.IsGemm() && !scop_.is_spec_gemm_ && !cube_l0write_.empty()) { + if (info_.cube_info_.IsGemm() && !info_.cube_info_.IsSpecGemm() && !cube_l0write_.empty()) { cube_l0write_.emplace_back(Block::make(cube_l0write_[0], stmt)); stmt = Evaluate::make(0); } @@ -1787,24 +1886,24 @@ Stmt CCEIslEmitter::EmitMarkAllocRealizeOut(const isl::ast_node_mark &node) { Stmt CCEIslEmitter::EmitMarkAllocC(const isl::ast_node_mark &node) { Stmt body = EmitAst(node.get_node()); body = RemoveNoOp(body); - body = HoistL0write(scop_, body, cube_l0write_); + body = HoistL0write(info_, body, cube_l0write_); - auto c_ub = scop_.is_spec_gemm_ ? scop_.GetCName() : scop_.GetCName() + "_local_UB"; + auto c_ub = info_.cube_info_.IsSpecGemm() ? info_.cube_info_.GetCName() : info_.cube_info_.GetCName() + "_local_UB"; auto c_l0c = c_ub + "_local_L0C"; - body = InsertRealize(body, isl::id(scop_.ctx_, c_l0c), false); - body = InsertRealize(body, isl::id(scop_.ctx_, c_ub), false); + body = InsertRealize(body, isl::id(info_.GetCtx(), c_l0c), false); + body = InsertRealize(body, isl::id(info_.GetCtx(), c_ub), false); body = AttrStmt::make(make_zero(Int(32)), ALLOC_C, Expr(1), body); return body; } Stmt CCEIslEmitter::EmitMarkSpecGemm(const isl::ast_node_mark &node) { - scop_.UpdateFractalIntInfo(++isolate_idx_); + info_.cube_info_.UpdateFractalIntInfo(++isolate_idx_); Expr mad_init_cond; ConvBackPropFilterFixMadInit(node, mad_init_cond); - if (scop_.out_reduce_init_ == 0) { + if (info_.cube_info_.GetOutReduceInit() == 0) { mad_init_cond = Expr(0); } - Stmt stmt = scop_.ConstructPolyGemm(mad_init_cond); + Stmt stmt = SpecGemmBuilder(info_).Build(mad_init_cond); return EmitSpecGemL1write(node, stmt); } @@ -1847,10 +1946,11 @@ void CCEIslEmitter::RealizeOut() { // Now we just judge whole loop's liveness from existing WAR. // It is correct in gemm, conv etc. but may be wrong in other cases. - std::string tensor_name = scop_.GetOriginTensorId(j).get_name(); - if (scop_.MayWriteAfterRead(tensor_name)) { + std::string tensor_name = info_.GetOriginTensorId(j).get_name(); + if (info_.MayWriteAfterRead(tensor_name)) { bool do_out = true; - auto c_ub = scop_.is_spec_gemm_ ? scop_.GetCName() : scop_.GetCName() + "_local_UB"; + auto c_ub = + info_.cube_info_.IsSpecGemm() ? 
info_.cube_info_.GetCName() : info_.cube_info_.GetCName() + "_local_UB"; auto c_l0c = c_ub + "_local_L0C"; if (j.get_name() == c_ub || j.get_name() == c_l0c) { do_out = false; @@ -1897,7 +1997,7 @@ Stmt CCEIslEmitter::EmitMarkMulticore(const isl::ast_node_mark &node) { } else if (node.get_node().as()) { auto stmt = EmitAst(node.get_node()); for (const auto &var : realize_must_def_) { - Tensor t = scop_.FindTensor(var); + Tensor t = info_.FindTensor(var); Region bounds; for (auto j : t->shape) { bounds.push_back(Range::make_by_min_extent(Expr(0), j)); @@ -2386,28 +2486,29 @@ Stmt CCEIslEmitter::Emit(const isl::ast_node &node) { Stmt stmt = EmitAst(node); stmt = RemoveCond(stmt); /// emit global realize - if (!scop_.is_spec_gemm_) { + if (!info_.cube_info_.IsSpecGemm()) { for (const auto &i : global_realize_out_) { - Tensor t = scop_.FindTensor(i); + Tensor t = info_.FindTensor(i); if (realized_.count(t)) continue; stmt = InsertRealize(stmt, i, false); } for (const auto &i : realize_out_) { - Tensor t = scop_.FindTensor(i); + Tensor t = info_.FindTensor(i); if (realized_.count(t)) continue; stmt = InsertRealize(stmt, i, false); } - for (const auto &i : scop_.realize_from_input_) { - if (FindUsingTensor(stmt).found(i.get_name()) && !scop_.IsInBinds(i.get_name())) { + auto realize_from_input = info_.user_config_.GetRealizeFromInput(); + for (const auto &i : realize_from_input) { + if (FindUsingTensor(stmt).found(i.get_name()) && !info_.IsInBinds(i.get_name())) { stmt = InsertRealize(stmt, i, false); } } auto not_realized_tensors = FindNotRealizedTensors().Find(stmt); for (const auto ¬_realized_tensor : not_realized_tensors) { - isl::id var = isl::id(scop_.ctx_, not_realized_tensor); + isl::id var = isl::id(info_.GetCtx(), not_realized_tensor); if (!FindRealizeScopeToString(var).empty()) { // The tensor needs to be realized somewhere, but it is not realized in the correct scope. // So, we realize it in the outermost scope to fix it. @@ -2427,6 +2528,59 @@ Stmt CCEIslEmitter::Emit(const isl::ast_node &node) { } return stmt; } + +void GetNameWithoutLocal(isl::id &tensor_id, ScopInfo &info) { + if (!info.cube_info_.IsSpecGemm()) { + size_t pos = tensor_id.get_name().find("_local_"); + std::string substr = tensor_id.get_name().substr(0, pos); + if (pos != 0) tensor_id = isl::id(tensor_id.ctx(), substr); + } +} + +isl::multi_aff CCEIslEmitter::TensorAccessMultAff(isl::id &tensor_id, const Array &tensor_index, + const isl::id &node_id) { + GetNameWithoutLocal(tensor_id, info_); + return IslEmitter::TensorAccessMultAff(tensor_id, tensor_index, node_id); +} + +bool CCEIslEmitter::IsCopyinFromAnotherBand(isl::multi_aff &access) { + if (!info_.cube_info_.IsSpecGemm()) { + return IslEmitter::IsCopyinFromAnotherBand(access); + } + return false; +} + +bool CCEIslEmitter::IsTransferStmt() { + if (!info_.cube_info_.IsSpecGemm()) { + return IslEmitter::IsTransferStmt(); + } + return false; +} + +Stmt CCEIslEmitter::EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, + BufferedFootPrintInfo &buffer_footprint_info) { + const Call *call = static_cast(node); + Array args; + for (auto iv : call->args) { + args.push_back(ReplaceLoopVar(var_map_tmp).Mutate(iv)); + } + // Not hoisted, emitting just the mapped subscript. 
+ if (!buffer_footprint_info.cluster_id) { + std::string call_name = call->name; + if (IsTransferStmt() && (std::string::npos == call_name.find("_local_UB"))) { + call_name = call_name + "_local_UB"; + Tensor t = info_.FindTensor(call_name); + if (t.defined()) { + return Evaluate::make(Call::make(call->type, call_name, args, call->call_type, t->op, call->value_index)); + } else { + LOG(WARNING) << "Call can not found tensor!!! tensor name: " << call_name; + } + } + return Evaluate::make(Call::make(call->type, call->name, args, call->call_type, call->func, call->value_index)); + } + return Stmt(); +} + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/cce_isl_emitter.h b/src/poly/cce_isl_emitter.h index 0a8f035d7c811b953ebe56977a39f6507424c8bd..7bd9a0ff206f96f330992ec8f89a4af0ee0edd5c 100644 --- a/src/poly/cce_isl_emitter.h +++ b/src/poly/cce_isl_emitter.h @@ -16,13 +16,7 @@ #ifndef POLY_CCE_ISL_EMITTER_H_ #define POLY_CCE_ISL_EMITTER_H_ -#include -#include -#include - #include "ir_pass.h" -#include "isl.h" -#include "scop.h" #include "isl_emitter.h" namespace akg { @@ -39,12 +33,15 @@ class Liveness { std::vector read_; std::vector write_; }; +enum AtomicType { Equ = 0, Add }; /*! * IslEmitter for CCE */ class CCEIslEmitter : public IslEmitter { public: - CCEIslEmitter(Scop &s, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(s, n, i) { ProcBypassL1(s); } + CCEIslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(info, n, i) { + ProcBypassL1(info); + } ~CCEIslEmitter() override = default; Stmt Emit(const isl::ast_node &node) final; @@ -52,7 +49,6 @@ class CCEIslEmitter : public IslEmitter { private: // override emitters for CCE Stmt EmitFor(const isl::ast_node_for &node) final; - Stmt EmitIf(const isl::ast_node_if &node) final; Stmt EmitMark(const isl::ast_node_mark &node_id) override; Stmt EmitBlock(const isl::ast_node_block &node) final; Stmt EmitStmt(const isl::ast_node_user &node) final; @@ -60,60 +56,68 @@ class CCEIslEmitter : public IslEmitter { // DMA emitters for CCE Expr EmitLoad(const isl::ast_expr &lhs, Type type); - Stmt EmitL1Read(const isl::ast_node_user &node); - Stmt EmitL1Write(const isl::ast_node_user &node, Scop::AtomicType atomic); + Stmt EmitRead(const isl::ast_node_user &node); + Stmt EmitWrite(const isl::ast_node_user &node, AtomicType atomic); Stmt EmitSpecGemL1write(const isl::ast_node_mark &node, const Stmt &stmt); - // RangeInfo emitters for CCE - Stmt EmitGemmRangeInfoBackPropFilter(const Stmt &stmt); - Stmt EmitGemmRangeInfo(Stmt stmt); - - // multicore emitters for CCE + // emit mark node Stmt EmitMarkMulticore(const isl::ast_node_mark &node); - bool InjectMulticore(const std::string &iter); - Stmt EmitMarkFuseVector(const isl::ast_node_mark &node); Stmt EmitMarkAllocRealizeOut(const isl::ast_node_mark &node); Stmt EmitMarkAllocC(const isl::ast_node_mark &node); Stmt EmitMarkSpecGemm(const isl::ast_node_mark &node); + // emit attrs void EmitAttrStmt(const isl::ast_node_block &block_node, const Liveness &liveness, bool is_L1, bool is_L0, std::vector &stmts); - - void EmitAttrStmtL0(Tensor &t, bool &is_im2col, bool &is_filter_l0, bool &is_gemm_data_trans, - bool &is_gemm_weight_trans); - - void EmitAttrStmtL1(Tensor &t, bool &is_fractal, bool &is_filter_l1); - - void EmitAttrStmtLiveness(const Liveness &liveness, std::vector &stmts, int i, bool is_L1); - + void EmitReadAttrAtL0(std::vector &stmts, int i, Tensor &t); + void EmitReadAttrAtL1(std::vector &stmts, int i, Tensor &t); + void 
EmitReadAttr(const std::vector &read, std::vector &stmts, int i, bool is_L1, bool is_L0); + void EmitWriteAttr(const std::vector &write, std::vector &stmts, int i, bool is_L1); void EmitAttrStmtAfterRealize(bool is_L1, bool is_L0, std::vector &stmts); + Stmt EmitGemmRangeInfoBackPropFilter(const Stmt &stmt); + Stmt EmitGemmRangeInfo(Stmt stmt); + + // emit realize void EmitRealize(const isl::ast_node_block &block_node, const Liveness &liveness_info, bool is_L1, bool is_L0, std::vector &stmts); - void EmitRealizeLivenessInfo(std::vector &real, const Liveness &liveness_info, - std::unordered_map, isl::IslIdIslHash> &liveness, - std::function const &CheckGoOut); - void EmitGemmRangeInfoNewAxis(std::vector &range, std::vector &prefix, - std::unordered_map &outerAxis, Range &axisMRange, - Map &range_map, Map &axis_map); + // emit access + Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info) override; - void EmitGemmRangeInfoDynamic(Range &axisMRange, Map &range_map); - void EmitGemmRangeInfoStatic(Map &range_map); - // realize info for CCE + // tool func + bool InjectMulticore(const std::string &iter); + void CollectLiveness(const Liveness &liveness_info, bool is_L1, std::vector &real, + std::unordered_map, isl::IslIdIslHash> &liveness, + std::function const &CheckGoOut); + void CollectGemmRangeInfoNewAxis(std::vector &range, std::vector &prefix, + std::unordered_map &outerAxis, Range &axisMRange, + Map &range_map, Map &axis_map); + + void CollectGemmMWSize(Range &axis_m_range, Map &range_map); + void CollectGemmMWSizeDynamic(Map &range_map); Expr FindRealizeScope(const isl::id &var); std::string FindRealizeScopeToString(const isl::id &var); Stmt InsertRealize(Stmt stmt, const isl::id &var, bool is_L0); void RealizeOut(); Stmt RemoveCond(const Stmt &stmt); - void ProcBypassL1(const Scop &scop); + void ProcBypassL1(const ScopInfo &info); void SetCube(const isl::id &stmt_id); std::string ReplaceAxis(const std::string &oldAxis); static std::vector ConstructPrefix(); void GemmTranspose(std::vector &stmts); void ConvBackPropFilterFixMadInit(const isl::ast_node_mark &node, Expr &mad_init_cond); + isl::multi_aff TensorAccessMultAff(isl::id &tensor_id, const Array &subscripts, + const isl::id &stmt_id) override; + bool IsTransferStmt() override; + bool IsCopyinFromAnotherBand(isl::multi_aff &access) override; + + std::map iters_old_name_; + std::map iters_new_name_; + std::unordered_map stmt_var_map_; + std::set realized_; IslIdSet hoisted_read_; IslIdSet hoisted_write_; diff --git a/src/poly/construct_poly_accesses.cc b/src/poly/construct_poly_accesses.cc new file mode 100644 index 0000000000000000000000000000000000000000..f6b51c48d61b4f47065e854e22e415b310a654fe --- /dev/null +++ b/src/poly/construct_poly_accesses.cc @@ -0,0 +1,562 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "poly/construct_poly_accesses.h" + +#include +#include + +#include +#include + +#include "poly/scop_builder.h" +#include "pass/utils.h" + +namespace akg { +namespace ir { +namespace poly { + +std::pair ConstructPolyAccess(const OperatorDomainSpace &domain, const Node *op, + const std::string &tensor, const Array &dimensions, + AccessMap &accesses) { + // create a tensor coordinate to store the accessed relation + auto coordinate = + CollectTensorCoordinate(domain.param_space, isl::id(domain.param_space.ctx(), tensor), dimensions.size()); + auto tensor_space = coordinate.get_space(); + + // create a fully access set + isl::set tensor_access = isl::set::universe(tensor_space); + + // add access relation constraint for each parameter of one dimension + auto identity = isl::multi_aff::identity(tensor_space.map_from_set()); + for (size_t dim_idx = 0; dim_idx < dimensions.size(); ++dim_idx) { + // make aff bounds of each dimension. + auto domain_aff_bounds = Expr2Aff(domain.param_space, dimensions[dim_idx]); + if (!domain_aff_bounds.is_null()) { + domain_aff_bounds = domain_aff_bounds.unbind_params_insert_domain(coordinate); + tensor_access = tensor_access.intersect(domain_aff_bounds.eq_set(identity.get_aff(static_cast(dim_idx)))); + } + } + + auto tensor_map = + AddSuffix4Accesses(accesses, tensor_access.unbind_params_insert_domain(domain.tuple), op, domain.param_space.ctx()); + + return {tensor_map, isl::map::from(identity)}; +} + +std::tuple ConstructPolyAccesses(const OperatorDomainSpace &domain, + const Stmt &s, AccessMap &accesses) { + class AttrsExtractor final : public IRVisitor { + public: + AttrsExtractor() {} + ~AttrsExtractor() override = default; + + void Apply(const Stmt &s) { IRVisitor::Visit(s); } + + void Visit_(const AttrStmt *op) override { + if (op->attr_key == ATTR_IM2COL_KEY) { + Map var_map = Downcast>(op->node); + for (auto item : var_map) { + if (item.first == ATTR_PRAGMA_OUT_H) { + m_out_h = item.second.as() != nullptr ? static_cast(item.second.as()->value) : 0; + } else if (item.first == ATTR_PRAGMA_OUT_W) { + m_out_w = item.second.as() != nullptr ? 
static_cast(item.second.as()->value) : 0; + } + } + } + IRVisitor::Visit_(op); + } + + void Visit_(const Evaluate *op) override { + CHECK(op); + const int im2_col_arg_num = 23; + enum Im2colCallIndex { + idxStrideH = 7, + idxStrideW, + idxKernelH, + idxKernelW, + idxPadTop = 17, + idxPadBottom, + idxPadLeft, + idxPadRight + }; + const Call *call = op->value.as(); + CHECK(call); + auto getCallValue = [&call](const Im2colCallIndex &idx) { + if (auto item = call->args[static_cast(idx)].as()) { + return static_cast(item->value); + } + return 0; + }; + if (call->name == CALL_IM2COL_UB && call->args.size() == im2_col_arg_num) { + m_strid_h = getCallValue(Im2colCallIndex::idxStrideH); + m_strid_w = getCallValue(Im2colCallIndex::idxStrideW); + m_kernel_h = getCallValue(Im2colCallIndex::idxKernelH); + m_kernel_w = getCallValue(Im2colCallIndex::idxKernelW); + m_pad_top = getCallValue(Im2colCallIndex::idxPadTop); + m_pad_bottom = getCallValue(Im2colCallIndex::idxPadBottom); + m_pad_left = getCallValue(Im2colCallIndex::idxPadLeft); + m_pad_right = getCallValue(Im2colCallIndex::idxPadRight); + } + IRVisitor::Visit_(op); + } + + int KernelH() const { return m_kernel_h; } + + int KernelW() const { return m_kernel_w; } + int OutH() const { return m_out_h; } + int OutW() const { return m_out_w; } + int StrideH() const { return m_strid_h; } + int StrideW() const { return m_strid_w; } + int PadLeft() const { return m_pad_left; } + int PadRight() const { return m_pad_right; } + int PadTop() const { return m_pad_top; } + int PadBottom() const { return m_pad_bottom; } + + private: + int m_kernel_h{0}; + int m_kernel_w{0}; + int m_out_h{0}; + int m_out_w{0}; + int m_strid_h{0}; + int m_strid_w{0}; + int m_pad_left{0}; + int m_pad_right{0}; + int m_pad_top{0}; + int m_pad_bottom{0}; + }; + class RelationAccessesParser final : public IRVisitor { + public: + isl::map ExtractIm2ColReadAccess(const std::string &tensor, const Array &shape) { + const int arg_num = shape.size(); + isl::space param_space = domain.param_space; + isl::id tensor_id(param_space.ctx(), tensor); + auto coordinate = CollectTensorCoordinate(param_space, tensor_id, arg_num); + auto tensor_space = coordinate.get_space(); + + isl::set access = isl::set::universe(tensor_space); + auto identity = isl::multi_aff::identity(tensor_space.map_from_set()); + // need to optimize automatic add this exprs + Array args; + auto arg_size = static_cast(param_space.dim(isl_dim_param)); + int k_h = extractor.KernelH(); + int k_w = extractor.KernelW(); + int o_h = extractor.OutH(); + int o_w = extractor.OutW(); + if (arg_size == 3) { + CHECK(shape[0].as()); + args.push_back(shape[0].as()->value > 0 ? 
static_cast(Var("i")) : Expr(0)); + } else { + args.push_back(VarExpr("j") * Expr(16) / Expr(o_h * o_w)); + } + VarExpr k("k"); + CHECK_GT(k_h, 0); + CHECK_GT(k_w, 0); + Expr v = k / Expr(k_h * k_w); + args.push_back(v); + for (size_t i = 0; i < args.size(); ++i) { + auto range_point = identity.get_aff(static_cast(i)); + auto domain_point = Expr2Aff(param_space, args[i]); + if (!domain_point.is_null()) { + domain_point = domain_point.unbind_params_insert_domain(coordinate); + access = access.intersect(domain_point.eq_set(range_point)); + } + } + auto map = access.unbind_params_insert_domain(domain.tuple); + + std::string tag = "__poly_ref_0"; + isl::id tag_id(domain.param_space.ctx(), tag); + auto domain_space = map.get_space().domain(); + auto tag_space = domain_space.params().add_named_tuple_id_ui(tag_id, 0); + domain_space = domain_space.product(tag_space).unwrap(); + map = map.preimage_domain(isl::multi_aff::domain_map(domain_space)); + enum FeatureMapIndex { kBatchIndex = 0, kC1Index, kHIndex, kWIndex, kC0Index, KFeatureMapSiz }; + + CHECK_EQ(shape.size(), FeatureMapIndex::KFeatureMapSiz); + isl::set range = map.range(); + /*********************** + * no cut in H axis + * 0<= arg2 <= fm_h-1 + * 0<= arg3 <= fm_w-1 + * 0<= arg4 <= 16-1 + ************************/ + if (arg_size == 2) { + range = range.lower_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kBatchIndex), 0); + CHECK(shape[static_cast(FeatureMapIndex::kBatchIndex)].as()); + range = range.upper_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kBatchIndex), + shape[static_cast(FeatureMapIndex::kBatchIndex)].as()->value - 1); + } + CHECK(shape[static_cast(FeatureMapIndex::kHIndex)].as() && + shape[static_cast(FeatureMapIndex::kWIndex)].as() && + shape[static_cast(FeatureMapIndex::kC0Index)].as()); + + range = range.lower_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kHIndex), 0); + range = range.upper_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kHIndex), + shape[static_cast(FeatureMapIndex::kHIndex)].as()->value - 1); + range = range.lower_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kWIndex), 0); + range = range.upper_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kWIndex), + shape[static_cast(FeatureMapIndex::kWIndex)].as()->value - 1); + range = range.lower_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kC0Index), 0); + range = range.upper_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kC0Index), + shape[static_cast(FeatureMapIndex::kC0Index)].as()->value - 1); + + map = map.intersect_range(range); + + return map; + } + + bool UpdateAccess(const Array &shape) const { + const size_t kHIndex = 2; + const int largeHSize = 200; + Expr fm_h = shape[kHIndex]; + if (extractor.PadTop() > 0 && extractor.PadBottom() > 0 && extractor.PadLeft() > 0 && extractor.PadRight() > 0 && + Compare(fm_h, Expr(largeHSize)) > 0) { + return true; + } + return false; + } + + std::string getConstraint(const std::string &min_j, const std::string &max_j, const std::string &min_h, + const std::string &max_h) { + std::ostringstream ss; + ss << "(" << min_j << " <= j <= " << max_j << " and " << min_h << " <= arg2 <= " << max_h << ")"; + std::string set_con = ss.str(); + return set_con; + } + + std::string toString(int i) { + std::ostringstream ss; + ss << i; + return ss.str(); + } + + std::string body(bool left) { + std::ostringstream ss; + if (left) { + ss << extractor.StrideH() << "j/" << extractor.KernelH() << " - " << extractor.PadLeft(); + } else { + ss << extractor.StrideH() << "j/" << extractor.KernelH() << " + " 
<< extractor.PadRight(); + } + return ss.str(); + } + + void UpdatePaddingConstraint(const Expr &fmH) { + int size_h = 0; + if (fmH.as()) { + size_h = static_cast(fmH.as()->value); + } + const int mi = 16; + const int cut_h = 2; + int size_m = extractor.OutH() * extractor.OutW() / mi; + int head_m = cut_h * extractor.OutW() / mi; + + int head_h = extractor.KernelH() + (cut_h - 1) * extractor.StrideH() - extractor.PadTop() - 1; + int tail_h = (extractor.OutH() - cut_h) * extractor.StrideH() - extractor.PadTop(); + + std::string head_con = getConstraint(toString(0), toString(head_m - 1), toString(0), toString(head_h)); + std::string tail_con = + getConstraint(toString(size_m - head_m), toString(size_m - 1), toString(tail_h), toString(size_h - 1)); + std::string body_con = getConstraint(toString(head_m), toString(size_m - head_m - 1), body(true), body(false)); + + auto map_str = reads.to_str(); + std::string constraint = " (" + head_con + " or " + body_con + " or " + tail_con + ") "; + size_t endPos = map_str.find("}"); + std::string main = map_str.substr(0, endPos); + main = main + " and " + constraint + " }"; + isl_union_map *read_tmp = isl_union_map_read_from_str(reads.ctx().get(), main.c_str()); + CHECK(read_tmp); + reads = isl::manage(read_tmp); + } + + isl::map ExtractIm2ColWriteAccess(const std::string &tensor, const Array &shape) { + int arg_num = shape.size(); + isl::space param_space = domain.param_space; + isl::id tensor_id(param_space.ctx(), tensor); + auto coordinate = CollectTensorCoordinate(param_space, tensor_id, arg_num); + auto tensor_space = coordinate.get_space(); + + isl::set access = isl::set::universe(tensor_space); + auto identity = isl::multi_aff::identity(tensor_space.map_from_set()); + // need to optimize automatic add this exprs + auto arg_size = static_cast(param_space.dim(isl_dim_param)); + Array args; + const std::vector consStr5D = {"i", "j", "k", "mi", "ni"}; + const std::vector consStr4D = {"j", "k", "mi", "ni"}; + enum ShapeDim { shape5D = 0, shape4D }; + ShapeDim mod = ShapeDim::shape5D; + if (consStr5D.size() == shape.size()) { + mod = ShapeDim::shape5D; + for (size_t i = 0; i < arg_size; ++i) { + if (i == 0) { + CHECK(shape[0].as()); + Expr e = shape[0].as()->value > 0 ? 
static_cast(Var(consStr5D[i])) : Expr(0); + args.push_back(e); + } else { + args.push_back(static_cast(Var(consStr5D[i]))); + } + } + } else if (consStr4D.size() == shape.size()) { + mod = ShapeDim ::shape4D; + for (size_t i = 0; i < arg_size; ++i) { + args.push_back(static_cast(Var(consStr4D[i]))); + } + } + + for (size_t i = 0; i < args.size(); ++i) { + auto range_point = identity.get_aff(static_cast(i)); + auto domain_point = Expr2Aff(param_space, args[i]); + if (!domain_point.is_null()) { + domain_point = domain_point.unbind_params_insert_domain(coordinate); + access = access.intersect(domain_point.eq_set(range_point)); + } + } + + auto map = access.unbind_params_insert_domain(domain.tuple); + + std::string tag = "__poly_ref_1"; + isl::id tag_id(domain.param_space.ctx(), tag); + auto domain_space = map.get_space().domain(); + auto tag_space = domain_space.params().add_named_tuple_id_ui(tag_id, 0); + domain_space = domain_space.product(tag_space).unwrap(); + map = map.preimage_domain(isl::multi_aff::domain_map(domain_space)); + + enum FractalIndex { idxBatch = 0, idxMo, idxKo, idxMi, idxKi, fractalSize }; + /*********************** + * mi ni range definition + * 0<= arg3 <= 16-1 + * 0<= arg4 <= 16-1 + ************************/ + CHECK_EQ(shape.size(), FractalIndex::fractalSize - mod); + CHECK(shape[static_cast(FractalIndex::idxMi - mod)].as() && + shape[static_cast(FractalIndex::idxKi - mod)].as()); + isl::set range = map.range(); + + range = range.lower_bound_si(isl_dim_set, static_cast(FractalIndex::idxMi - mod), 0); + range = range.upper_bound_si(isl_dim_set, static_cast(FractalIndex::idxMi - mod), + shape[static_cast(FractalIndex::idxMi - mod)].as()->value - 1); + + range = range.lower_bound_si(isl_dim_set, static_cast(FractalIndex::idxKi - mod), 0); + range = range.upper_bound_si(isl_dim_set, static_cast(FractalIndex::idxKi - mod), + shape[static_cast(FractalIndex::idxKi - mod)].as()->value - 1); + map = map.intersect_range(range); + + return map; + } + + void Visit_(const Evaluate *op) final { + IRVisitor::Visit_(op); + const Call *call_op = op->value.as(); + if (call_op && call_op->name == CALL_IM2COL_UB) { + CHECK_GE(call_op->args.size(), 2); + CHECK(call_op->args[0].as()); + CHECK_GE(call_op->args[0].as()->args.size(), 2); + CHECK(call_op->args[0].as()->args[1].as()); + CHECK(call_op->args[1].as()); + CHECK_GE(call_op->args[1].as()->args.size(), 2); + CHECK(call_op->args[1].as()->args[1].as()); + std::string write_buffer = call_op->args[0].as()->args[1].as()->name_hint; + std::string read_buffer = call_op->args[1].as()->args[1].as()->name_hint; + for (auto item : accesses) { + if (item.first->IsInstance()) { + auto attr = static_cast(item.first); + Array array = Downcast>(attr->node); + Buffer buffer = Downcast(array[0]); + Tensor tensor = Downcast(array[1]); + if (buffer->name == read_buffer) { + isl::map readIm2Col = ExtractIm2ColReadAccess(tensor->op->name, tensor->shape); + reads = reads.unite(readIm2Col); + if (UpdateAccess(tensor->shape)) { + UpdatePaddingConstraint(tensor->shape[2]); + } + } else if (buffer->name == write_buffer) { + isl::map writeIm2Col = ExtractIm2ColWriteAccess(tensor->op->name, tensor->shape); + writes = writes.unite(writeIm2Col); + } + } + } + } + } + + void Visit_(const Call *op) final { + IRVisitor::Visit_(op); + if (op->call_type == Call::Halide) { + isl::map reads_tmp, toinner_tmp; + std::string var_name = op->name; + if (op->func.defined() && op->func->num_outputs() != 1) { + var_name = var_name + "_v" + std::to_string(op->value_index); + } + 
std::tie(reads_tmp, toinner_tmp) = ConstructPolyAccess(domain, op, var_name, op->args, accesses); + reads = reads.unite(reads_tmp); + to_inner_ = to_inner_.add_map(toinner_tmp); + } + } + + void Visit_(const Provide *op) final { + IRVisitor::Visit_(op); + isl::map writes_tmp, toinner_tmp; + std::string var_name = op->func->func_name(); + if (op->func->num_outputs() != 1) { + var_name = var_name + "_v" + std::to_string(op->value_index); + } + std::tie(writes_tmp, toinner_tmp) = ConstructPolyAccess(domain, op, var_name, op->args, accesses); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.add_map(toinner_tmp); + } + + /* The conditionals of IfThenElse statements may fall in these cases. + * The accesses should be updated to read sets of scop as such accesses + * may only be read. + * + * More complicated cases like conditionals involving Store and/or + * Provide should also update write sets. + */ + void Visit_(const EQ *op) final { + isl::union_map reads_tmp, writes_tmp, toinner_tmp; + + Stmt stmt_a(GetObjPtr(op->a.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + + Stmt stmt_b(GetObjPtr(op->b.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + } + + void Visit_(const NE *op) final { + isl::union_map reads_tmp, writes_tmp, toinner_tmp; + + Stmt stmt_a(GetObjPtr(op->a.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + + Stmt stmt_b(GetObjPtr(op->b.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + } + + void Visit_(const LT *op) final { + isl::union_map reads_tmp, writes_tmp, toinner_tmp; + + Stmt stmt_a(GetObjPtr(op->a.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + + Stmt stmt_b(GetObjPtr(op->b.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + } + + void Visit_(const LE *op) final { + isl::union_map reads_tmp, writes_tmp, toinner_tmp; + + Stmt stmt_a(GetObjPtr(op->a.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + + Stmt stmt_b(GetObjPtr(op->b.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + } + + void Visit_(const GT *op) final { + isl::union_map reads_tmp, writes_tmp, toinner_tmp; + + Stmt stmt_a(GetObjPtr(op->a.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, 
accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + + Stmt stmt_b(GetObjPtr(op->b.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + } + + void Visit_(const GE *op) final { + isl::union_map reads_tmp, writes_tmp, toinner_tmp; + + Stmt stmt_a(GetObjPtr(op->a.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + + Stmt stmt_b(GetObjPtr(op->b.get())); + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + } + + // End of conditionals of IfThenElse, more cases are pending. + + /* A For type statement may be visited in the presence of + * IfThenElse in the scop, as the body of the enclosing + * if statement. + * + * A Block type should be handled. + */ + + void Visit_(const For *op) final { + IRVisitor::Visit_(op); + isl::union_map reads_tmp, writes_tmp, toinner_tmp; + + std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, op->body, accesses); + reads = reads.unite(reads_tmp); + writes = writes.unite(writes_tmp); + to_inner_ = to_inner_.unite(toinner_tmp); + } + + const OperatorDomainSpace &domain; + AccessMap &accesses; + + isl::union_map reads, writes; + isl::union_map to_inner_; + AttrsExtractor extractor; + + RelationAccessesParser(const Stmt stmt, const OperatorDomainSpace &space, AccessMap &accesses) + : domain(space), + accesses(accesses), + reads(isl::union_map::empty(domain.tuple.get_space())), + writes(isl::union_map::empty(domain.tuple.get_space())), + to_inner_(isl::union_map::empty(domain.tuple.get_space())) { + extractor.Apply(stmt); + IRVisitor::Visit(stmt); + } + ~RelationAccessesParser() override = default; + } parser(s, domain, accesses); + return std::make_tuple(parser.reads, parser.writes, parser.to_inner_); +} +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/construct_poly_accesses.h b/src/poly/construct_poly_accesses.h new file mode 100644 index 0000000000000000000000000000000000000000..03f1966bd49ec975ec85048989cc42ce9811024b --- /dev/null +++ b/src/poly/construct_poly_accesses.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
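Before moving on to the matching header, a minimal caller sketch for the new ConstructPolyAccesses entry point may help; it is illustrative only (op_domain, body and the surrounding builder code are assumptions, not part of this patch):

    // Hypothetical caller inside the scop builder -- op_domain/body are placeholders.
    AccessMap accesses;
    isl::union_map reads, writes, to_inner;
    std::tie(reads, writes, to_inner) = ConstructPolyAccesses(op_domain, body, accesses);
    // reads/writes carry the __poly_ref_* tagged access relations built above;
    // to_inner keeps the per-access identity maps used later for footprint construction.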
+ */ + +#ifndef POLY_CONSTRUCT_POLY_ACCESSES_H_ +#define POLY_CONSTRUCT_POLY_ACCESSES_H_ + +#include +#include + +#include + +#include "poly/scop_info.h" + +namespace akg { +namespace ir { +namespace poly { + +std::pair ConstructPolyAccess(const OperatorDomainSpace &domain, const Node *op, + const std::string &tensor, const Array &dimensions, + AccessMap &accesses); + +std::tuple ConstructPolyAccesses(const OperatorDomainSpace &domain, + const Stmt &s, AccessMap &accesses); +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_CONSTRCUT_POLY_ACCESSES_H_ \ No newline at end of file diff --git a/src/poly/cce_optimizer.cc b/src/poly/davinci_halide_optimizer.cc similarity index 99% rename from src/poly/cce_optimizer.cc rename to src/poly/davinci_halide_optimizer.cc index 4033c5bb1e83f819460f885de3221e7c55b01698..563b4617d0b4ee65bafa156ae42dbac11d869e77 100644 --- a/src/poly/cce_optimizer.cc +++ b/src/poly/davinci_halide_optimizer.cc @@ -607,7 +607,7 @@ class DynamicPaddingFix : public IRMutator { std::string fm_l1_{""}; }; -Stmt OptimizeCce(const Stmt &s, bool dynamicShape = false) { +Stmt DavinciHalideOptimizer(const Stmt &s, bool dynamicShape = false) { Stmt stmt = s; if (dynamicShape) { stmt = InductionVarElinate().Run(s); diff --git a/src/poly/davinci_mgr_strategy.cc b/src/poly/davinci_mgr_strategy.cc new file mode 100644 index 0000000000000000000000000000000000000000..13581cb18f68c83a8be1082b89c3646c7f03a5d1 --- /dev/null +++ b/src/poly/davinci_mgr_strategy.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
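The rename from OptimizeCce to DavinciHalideOptimizer keeps the signature unchanged, so call sites need only a one-line edit; a hedged sketch (no real caller is modified in this patch, the variable names are illustrative):

    // before this patch: stmt = OptimizeCce(stmt, is_dynamic_shape);
    stmt = DavinciHalideOptimizer(stmt, is_dynamic_shape);  // same Stmt-in/Stmt-out Halide-level optimizer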
+ */ +#include "poly/davinci_mgr_strategy.h" + +#include "poly/schedule_pass/group.h" +#include "poly/schedule_pass/tile_outer_band.h" +#include "poly/schedule_pass/memory_manager.h" +#include "poly/schedule_pass/sink_c0.h" +#include "poly/schedule_pass/sink_last_axis.h" +#include "poly/schedule_pass/reorder_invariant_set_schedule.h" +#include "poly/schedule_pass/keep_outer_band_order.h" + +#include "poly/schedule_pass/split_outer_band.h" +#include "poly/schedule_pass/transfer_stmt.h" +#include "poly/schedule_pass/reset_coincidence_of_reduce.h" +#include "poly/schedule_pass/set_all_coincidence.h" +#include "poly/schedule_pass/reschedule.h" +#include "poly/schedule_pass/reorder_inner_band.h" +#include "poly/schedule_pass/change_marknode_position.h" +#include "poly/schedule_pass/insert_node_for_allocc.h" +#include "poly/schedule_pass/label_realize_out_position.h" +#include "poly/schedule_pass/mark_fuse_op.h" +#include "poly/schedule_pass/reorder_mark_nodes.h" +#include "poly/schedule_pass/compute_transfer_copyin.h" +#include "poly/schedule_pass/compute_inner_band_dependency.h" +#include "poly/schedule_pass/mark_outer_most.h" + +namespace akg { +namespace ir { +namespace poly { + +void DavinciMgrStrategy::RegisterTilingPasses() { + RegisterPass(std::make_shared(pass_info_, scop_info_)); +} + +void DavinciMgrStrategy::RegisterMemPromPasses() { RegisterPass(std::make_shared(scop_info_)); } + +void DavinciMgrStrategy::RegisterPasses() { + passes_.clear(); + RegisterNormalizationPasses(); + if (!scop_info_.user_config_.GetDisableGroup()) { + RegisterPass(std::make_shared(pass_info_)); + } + RegisterSchedulingPasses(); + RegisterPass(std::make_shared(pass_info_)); + if (scop_info_.user_config_.GetReorderSchedule()) { + RegisterPass(std::make_shared()); + } + if (scop_info_.user_config_.GetSinkLastAxis()) { + RegisterPass(std::make_shared(pass_info_)); + } + if (scop_info_.user_config_.GetKeepOuterBandOrder()) { + RegisterPass(std::make_shared(scop_info_)); + } + RegisterPass(std::make_shared(pass_info_)); + if (scop_info_.user_config_.GetOuterBandNeedSplit() && !scop_info_.cube_info_.IsSpecGemm()) { + RegisterPass(std::make_shared()); + } + RegisterPass(std::make_shared(scop_info_)); + if (!scop_info_.cube_info_.IsSpecGemm() && (scop_info_.cube_info_.IsConv() || scop_info_.cube_info_.IsGemm())) { + RegisterPass(std::make_shared(scop_info_, pass_info_)); + } + if (scop_info_.user_config_.GetIsTuning()) { + return; + } + RegisterTilingPasses(); + RegisterPass(std::make_shared(pass_info_)); + RegisterPass(std::make_shared(scop_info_, pass_info_)); + if (scop_info_.user_config_.GetPragmaSetAllCoincident()) { + RegisterPass(std::make_shared()); + } + if (!scop_info_.user_config_.GetIsDynamic() || !scop_info_.cube_info_.IsConv()) { + RegisterPass(std::make_shared(scop_info_, pass_info_)); + } + RegisterPass(std::make_shared(scop_info_.analysis_result_.GetCondVarsMap())); + RegisterPass(std::make_shared(scop_info_.analysis_result_.ExtractWithStmtId())); + RegisterPass(std::make_shared()); + if (scop_info_.cube_info_.IsSpecGemm() || scop_info_.cube_info_.IsGemm() || + scop_info_.cube_info_.IsConvBackpropFilter()) { + RegisterPass(std::make_shared()); + } + RegisterMemPromPasses(); + if (!scop_info_.cube_info_.IsSpecGemm()) { + RegisterPass(std::make_shared(scop_info_, pass_info_)); + } + RegisterPass(std::make_shared()); + RegisterPass(std::make_shared()); + // if coincidence constraints are disabled (due to reschedule), we cannot determine multicore axis reliably + bool can_use_multiCore = 
!scop_info_.cube_info_.IsSpecGemm() && scop_info_.user_config_.GetConsiderCoincidence(); + if (can_use_multiCore || scop_info_.user_config_.GetEnableMarkMultiCore()) { + RegisterPass(std::make_shared(scop_info_)); + } +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/davinci_mgr_strategy.h b/src/poly/davinci_mgr_strategy.h new file mode 100644 index 0000000000000000000000000000000000000000..d015f1427d86b68075db1df0dc8f953825e64990 --- /dev/null +++ b/src/poly/davinci_mgr_strategy.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef POLY_DAVINCI_MGR_STRATEGY_H_ +#define POLY_DAVINCI_MGR_STRATEGY_H_ + +#include "poly/pass_mgr_strategy.h" + +namespace akg { +namespace ir { +namespace poly { +class DavinciMgrStrategy : public PassMgrStrategy { + public: + explicit DavinciMgrStrategy(ScopInfo &scop_info) : PassMgrStrategy(scop_info) { + pass_info_.coincident_ = scop_info_.user_config_.GetConsiderCoincidence(); + } + ~DavinciMgrStrategy() override = default; + + void RegisterTilingPasses() override; + void RegisterMemPromPasses() override; + void RegisterPasses() override; +}; + +} // namespace poly +} // namespace ir +} // namespace akg +#endif // POLY_DAVINCI_MGR_STRATEGY_H_ diff --git a/src/poly/dma_dataflow.cc b/src/poly/dma_dataflow.cc index 5f35786003ff36a7051798a3cb3ce876ab22ac4c..36290c775617e87702eee193616142b185e30e7f 100644 --- a/src/poly/dma_dataflow.cc +++ b/src/poly/dma_dataflow.cc @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
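DavinciMgrStrategy builds an ordered pass pipeline by pushing shared_ptr pass objects, most of them gated on user_config_ flags. The registration pattern, condensed (the template argument below is inferred from the corresponding include and flag, since the diff elides template arguments):

    // Pattern used throughout RegisterPasses(): construct the pass, gate it on a config flag.
    if (scop_info_.user_config_.GetReorderSchedule()) {
      RegisterPass(std::make_shared<ReorderInvariantSetSchedule>());  // appended to passes_ in registration order
    }

Tiling and memory promotion are split out into RegisterTilingPasses/RegisterMemPromPasses so that other backend strategies can override just those stages.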
*/ -#include "poly/dma_dataflow.h" -#include "poly/scop.h" +#include "poly/dma_dataflow.h" +#include "poly/poly_util.h" namespace akg { namespace ir { @@ -193,7 +193,6 @@ void StmtDataFlowInfo::AddWriteTensor(const std::string &name, TENSOR_DATAFLOW_T void StmtDataFlowInfo::CreateTensorDataFlow(TENSOR_DATAFLOW_TYPE type, const std::string &name, TensorDataFlow &dataflow) { CHECK_NE(name, ""); - dataflow.tensor_name_ = name; switch (type) { case TENSOR_DATAFLOW_TYPE::CUBE_CONV_A: CubeConvA(name, dataflow); diff --git a/src/poly/dma_dataflow.h b/src/poly/dma_dataflow.h index fc844387d388dc2d7572bc7e6b11eefcc0d20dd7..a5d1f52d6d8b1e3194a99f58fd8d26f7411471bf 100644 --- a/src/poly/dma_dataflow.h +++ b/src/poly/dma_dataflow.h @@ -33,7 +33,7 @@ namespace akg { namespace ir { namespace poly { class TensorFootprintCluster; -class TensorDataFlow; +struct TensorDataFlow; class StmtDataFlowInfo; enum MemType { DDR = 1, L1_, UB_, L0A_, L0B_, L0C_, UBL0_, UBL1_ }; @@ -142,7 +142,6 @@ enum TENSOR_DATAFLOW_TYPE { }; struct TensorDataFlow { - std::string tensor_name_; std::vector name_flow_; MemFlow mem_type_flow_; diff --git a/src/poly/dma_inject.cc b/src/poly/dma_inject.cc index be471f12c572c230c34dde45ce605ca62cc400d3..1cdf836ce9de9d9d25fcd3c3398b590fd59cf687 100644 --- a/src/poly/dma_inject.cc +++ b/src/poly/dma_inject.cc @@ -13,15 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "poly/dma_inject.h" - -#include -#include -#include -#include "poly/scop.h" -#include "poly/transform.h" +#include "poly/dma_inject.h" #include "poly/scop_builder.h" +#include "poly/schedule_pass.h" namespace akg { namespace ir { @@ -417,9 +412,10 @@ std::unique_ptr TensorFootprintCluster::ClusteringFootpr * 1. find n_dim & shape from binds based on tensor_id * 2. 
if found, update n_dim & shape from buf_def based on tensor_id * */ -void TensorShapeInfo(const Scop &scop, const isl::id &tensor_id, size_t &n_dim, Array &shape) { +void TensorShapeInfo(const ScopInfo &scop_info, const isl::id &tensor_id, size_t &n_dim, Array &shape) { n_dim = 0; - for (const auto &i : scop.binds_) { + auto binds = scop_info.user_config_.GetBind(); + for (const auto &i : binds) { if (i.first->op->name == tensor_id.get_name()) { n_dim = i.first.ndim(); shape = i.first->shape; @@ -427,7 +423,7 @@ void TensorShapeInfo(const Scop &scop, const isl::id &tensor_id, size_t &n_dim, } if (!n_dim) { - auto buf_def = scop.GetBufferDefInfo(tensor_id); + auto buf_def = scop_info.analysis_result_.GetBufferDefInfo(tensor_id); n_dim = buf_def.sizes.size(); for (auto i : buf_def.sizes) { shape.push_back(Expr(i)); @@ -435,11 +431,11 @@ void TensorShapeInfo(const Scop &scop, const isl::id &tensor_id, size_t &n_dim, } } -isl::set CollectTensorSet(const Scop &scop, const isl::id &tensor_id) { - auto space = scop.schedule_.get_domain().get_space(); +isl::set CollectTensorSet(const ScopInfo &scop_info, const isl::id &tensor_id, const isl::space &space) { + // auto space = scop.schedule_.get_domain().get_space(); size_t n_dim; Array shape; - TensorShapeInfo(scop, tensor_id, n_dim, shape); + TensorShapeInfo(scop_info, tensor_id, n_dim, shape); auto coordinate = CollectTensorCoordinate(space, tensor_id, n_dim); auto tensor_set = isl::set::universe(coordinate.get_space()); @@ -979,7 +975,7 @@ void AffineRefGroupConstructor::create() { } } -std::unique_ptr AffineRefGroupConstructor::ConstructRefGroup(Scop &scop, +std::unique_ptr AffineRefGroupConstructor::ConstructRefGroup(ScopInfo &scop_info, const isl::union_map &accesses, const isl::union_set &domain, const isl::union_map &schedule, @@ -987,7 +983,7 @@ std::unique_ptr AffineRefGroupConstructor::ConstructRefG for (auto a : accesses.get_map_list()) { auto tensor_id = a.get_tuple_id(isl_dim_out); // filter out tensor - if (affine_->NotNeedConstruct(tensor_id.get_name(), scop)) { + if (affine_->NotNeedConstruct(tensor_id.get_name(), scop_info)) { continue; } @@ -1051,7 +1047,7 @@ std::unique_ptr AffineRefGroupConstructor::AffineMapFoot return tensorGroup; } -std::unique_ptr ConstructAffineFpCluster(Scop &scop, const isl::union_map &accesses, +std::unique_ptr ConstructAffineFpCluster(ScopInfo &scop_info, const isl::union_map &accesses, const isl::union_set &domain, const isl::union_map &schedule, ReferenceType type, AffineType affine_type, AffineTensor right_matrix) { @@ -1081,32 +1077,32 @@ std::unique_ptr ConstructAffineFpCluster(Scop &scop, con case AffineType::AFFINE_IM2COL: { auto affine = static_cast(constructor.affine_); if (affine != nullptr) { - affine->attrInfo_ = scop.attr_info_; + affine->attrInfo_ = scop_info.cube_info_.GetConvAttrInfo(); } } break; case AffineType::AFFINE_WEIGHTTRANS: { auto affine = static_cast(constructor.affine_); if (affine != nullptr) { - affine->attrInfo_ = scop.attr_info_; + affine->attrInfo_ = scop_info.cube_info_.GetConvAttrInfo(); } } break; case AffineType::AFFINE_FRACTAL: { auto affine = static_cast(constructor.affine_); if (affine != nullptr) { - affine->attrInfo_ = scop.attr_info_; + affine->attrInfo_ = scop_info.cube_info_.GetConvAttrInfo(); } } break; default: break; } - return constructor.ConstructRefGroup(scop, accesses, domain, schedule, type); + return constructor.ConstructRefGroup(scop_info, accesses, domain, schedule, type); } -void AddAllBufferFootprintOfTensor(const Scop &scop, const isl::id 
&tensor_id, +void AddAllBufferFootprintOfTensor(const ScopInfo &scop_info, const isl::id &tensor_id, std::unordered_set &buffered_tensors) { buffered_tensors.insert(tensor_id); - for (const auto &info : scop.buffer_def_infos_) { + for (const auto &info : scop_info.analysis_result_.buffer_def_infos_) { if (info.dst_tensor_id == tensor_id) { buffered_tensors.insert(info.ancester_tensor_id); } @@ -1132,15 +1128,15 @@ std::unordered_set GatherStatementsInSubtree(const i return statements; } -bool IsExtensionUsedInSubTree(const Scop &scop, const isl::schedule_node &tree, const isl::union_map &extension, - const isl::union_map &accesses) { +bool IsExtensionUsedInSubTree(const ScopInfo &scop_info, const isl::schedule_node &tree, + const isl::union_map &extension, const isl::union_map &accesses) { auto statements = GatherStatementsInSubtree(tree); std::unordered_set promoted_tensors; extension.foreach_map([&](const isl::map &footprint) -> void { if (!footprint.range().is_wrapping()) return; const isl::id &tensor_id = footprint.range().unwrap().domain().unwrap().get_tuple_id(isl_dim_out); - AddAllBufferFootprintOfTensor(scop, tensor_id, promoted_tensors); + AddAllBufferFootprintOfTensor(scop_info, tensor_id, promoted_tensors); }); bool found_extension_in_subtree = false; @@ -1172,22 +1168,22 @@ isl::schedule_node InsertExtensionHere(isl::schedule_node &tree, const isl::sche * If they have the same range, then they will be in a same tile, and the footprint can be reused. * Otherwise, a new extension needs to be inserted. */ -isl::schedule_node InsertExtensionToFirstAccessedFilters(const Scop &scop, isl::schedule_node &tree, +isl::schedule_node InsertExtensionToFirstAccessedFilters(const ScopInfo &scop_info, isl::schedule_node &tree, const isl::union_map &extension, const isl::schedule_node &graft, isl_bool before, bool &found_extension_in_schedule) { found_extension_in_schedule = false; - if (scop.IsConv() || !tree.isa()) { + if (scop_info.cube_info_.IsConv() || !tree.isa()) { return tree; } - isl::union_map accesses = scop.data_.reads.unite(scop.data_.writes); + isl::union_map accesses = scop_info.analysis_result_.GetReads().unite(scop_info.analysis_result_.GetWrites()); isl::union_set last_schedule_range; unsigned int n_children = tree.n_children(); for (unsigned int i = 0; i < n_children; ++i) { unsigned int child_idx = before ? i : n_children - 1 - i; - if (IsExtensionUsedInSubTree(scop, tree.get_child(child_idx), extension, accesses)) { + if (IsExtensionUsedInSubTree(scop_info, tree.get_child(child_idx), extension, accesses)) { tree = tree.child(child_idx).child(0); bool insert_here = false; @@ -1266,7 +1262,7 @@ isl::schedule_node DefaultInsertExtension(isl::schedule_node tree, const isl::sc * - filter: S_1[i0] * ... 
(original schedule) */ -isl::schedule_node InsertExtensionBeforeOrAfter(const Scop &scop, isl::schedule_node tree, +isl::schedule_node InsertExtensionBeforeOrAfter(const ScopInfo &scop_info, isl::schedule_node tree, const isl::union_map &extension, const isl::multi_union_pw_aff &schedule, isl_bool before) { if (tree.isa() && tree.parent().isa()) { @@ -1307,7 +1303,7 @@ isl::schedule_node InsertExtensionBeforeOrAfter(const Scop &scop, isl::schedule_ } bool found_extension_in_schedule = false; - tree = InsertExtensionToFirstAccessedFilters(scop, tree, extension, graft, before, found_extension_in_schedule); + tree = InsertExtensionToFirstAccessedFilters(scop_info, tree, extension, graft, before, found_extension_in_schedule); if (found_extension_in_schedule) { return tree; @@ -1316,47 +1312,9 @@ isl::schedule_node InsertExtensionBeforeOrAfter(const Scop &scop, isl::schedule_ } } -static std::string MemTypeToString(const MemType &memType) { - switch (memType) { - case MemType::UB_: - return "UB"; - case MemType::L1_: - return "L1"; - case MemType::UBL0_: - return "UBL0"; - case MemType::UBL1_: - return "UBL1"; - case MemType::L0A_: - return "L0A"; - case MemType::L0B_: - return "L0B"; - case MemType::L0C_: - return "L0C"; - case MemType::DDR: - return "GM"; - default: - return ""; - } -} - -static std::string GetIslReadName(Scop &scop, const isl::id &cluster_id) { - auto tensor_info = scop.GetBufferDefInfo(cluster_id); - MemType memType = tensor_info.SrcMemType(); - return MemTypeToString(memType) + "read"; -} - -static std::string GetIslWriteName(Scop &scop, const isl::id &cluster_id) { - if (scop.HasBufferDefInfo(cluster_id)) { - auto tensor_info = scop.GetBufferDefInfo(cluster_id); - MemType memType = tensor_info.DstMemType(); - return MemTypeToString(memType) + "write"; - } - return MemTypeToString(MemType::DDR) + "write"; -} - -isl::schedule_node PlaceIm2colBelowImpl(Scop &scop, isl::schedule_node tree, const TensorFootprintCluster &cluster, - const isl::map &footprint, const isl::set &original_elements, - const isl::set &read_set) { +isl::schedule_node PlaceIm2colBelowImpl(ScopInfo &scop_info, isl::schedule_node tree, + const TensorFootprintCluster &cluster, const isl::map &footprint, + const isl::set &original_elements, const isl::set &read_set) { bool reads = (!cluster.RichReadRelations().is_empty() && cluster.ReadNeedDma()); if (reads) { auto cluster_id = footprint.get_tuple_id(isl_dim_out); @@ -1367,20 +1325,47 @@ isl::schedule_node PlaceIm2colBelowImpl(Scop &scop, isl::schedule_node tree, con .wrap() .product(buffered_footprint); auto fp_space_identity = isl::multi_aff::identity(footprint.get_space().range().map_from_set()); - auto buffer_def = scop.GetBufferDefInfo(cluster_id); + auto buffer_def = scop_info.analysis_result_.GetBufferDefInfo(cluster_id); fp_space_identity = RemoveDimensionOfSizeOne(fp_space_identity, buffer_def.TensorSize(tree.parent())); auto extension_map = footprint.wrap().identity().domain_factor_domain().domain_factor_domain(); - isl::id read_id = isl::id(tree.ctx(), GetIslReadName(scop, cluster_id)); + isl::id read_id = isl::id(tree.ctx(), scop_info.GetIslReadName(cluster_id)); auto read_extension = extension_map.intersect_range(buffered_read).set_tuple_id(isl_dim_out, read_id); auto read_mupa = isl::multi_union_pw_aff(fp_space_identity.pullback( isl::multi_aff::wrapped_range_map(footprint.get_space().wrap().set_set_tuple_id(read_id)))); - tree = InsertExtensionBeforeOrAfter(scop, tree.get_child(0), read_extension, read_mupa, isl_bool_true); + tree = 
InsertExtensionBeforeOrAfter(scop_info, tree.get_child(0), read_extension, read_mupa, isl_bool_true); } - scop.schedule_ = tree.get_schedule(); return tree; } -void UpdateTensorShape(Scop &scop, const isl::map &read_extension) { +/* + * Update sizes of a specific tensor in order to support realize shape expansion in UB -> L1 strided copy + * param new_sizes: new shape of the tensor + * return: found or not found + */ +bool UpdateBufferDefInfoSizes(ScopInfo &info, const isl::id &tensor_id, const std::vector &new_sizes) { + for (auto &buffer_def_info : info.analysis_result_.buffer_def_infos_) { + // update the first occurrence + if (buffer_def_info.dst_tensor_id == tensor_id) { + auto old_sizes = buffer_def_info.sizes; + CHECK(old_sizes.size() == new_sizes.size()); + Array shapes; + for (size_t dim = 0; dim < new_sizes.size(); ++dim) { + size_t new_size = std::max(new_sizes[dim], old_sizes[dim]); + shapes.push_back(Expr(static_cast(new_size))); + } + Tensor tensor = placeholder(shapes, buffer_def_info.data_type, tensor_id.get_name()); + const Buffer buffer = decl_buffer(shapes, buffer_def_info.data_type, tensor_id.get_name()); + info.user_config_.SetBind(tensor, buffer); + + buffer_def_info.sizes = new_sizes; + buffer_def_info.tensor = tensor; + return true; + } + } + return false; +} + +void UpdateTensorShape(ScopInfo &scop_info, const isl::map &read_extension) { ScopedFootprint foot_print = ComputeFootprintOfRange(read_extension.domain_factor_domain()); if (!foot_print.box.is_valid()) { return; @@ -1391,19 +1376,19 @@ void UpdateTensorShape(Scop &scop, const isl::map &read_extension) { for (const auto &size : foot_print.box.get_size().get_val_list()) { shape.push_back(size.get_num_si()); } - static_cast(scop.UpdateBufferDefInfoSizes(cluster_id, shape)); + static_cast(UpdateBufferDefInfoSizes(scop_info, cluster_id, shape)); } -isl::schedule_node InsertStmtExtension(Scop &scop, isl::schedule_node tree, isl::map read, isl::map read_extension, - const isl::union_map &raw_reads, const isl::union_map &raw_writes, - const isl::union_map &raw_copyin, const isl::union_map &schedule, - BufferDefInfo &def) { +isl::schedule_node InsertStmtExtension(ScopInfo &scop_info, isl::schedule_node tree, isl::map read, + isl::map read_extension, const isl::union_map &raw_reads, + const isl::union_map &raw_writes, const isl::union_map &raw_copyin, + const isl::union_map &schedule, BufferDefInfo &def) { isl::union_map reads = isl::union_map(read); isl::union_map writes = raw_writes.intersect_range(reads.range()); isl::union_map dependence = DependenceAnalysis(writes, reads, writes, schedule); isl::union_set stmt = dependence.domain().universe(); writes = raw_writes.intersect_domain(stmt); - UpdateTensorShape(scop, read_extension); + UpdateTensorShape(scop_info, read_extension); /* get stmt extension */ isl::union_map stmt_ext = isl::union_map(read_extension); @@ -1427,7 +1412,7 @@ isl::schedule_node InsertStmtExtension(Scop &scop, isl::schedule_node tree, isl: identity_copy_schedule = RemoveDimensionOfSizeOne(identity_copy_schedule, def.TensorSize(tree.parent())); isl::multi_union_pw_aff stmtSchedule = isl::multi_union_pw_aff(identity_copy_schedule); /* insert extension node */ - tree = InsertExtensionBeforeOrAfter(scop, tree.get_child(0), stmt_extension, stmtSchedule, isl_bool_true); + tree = InsertExtensionBeforeOrAfter(scop_info, tree.get_child(0), stmt_extension, stmtSchedule, isl_bool_true); } /* next */ @@ -1443,7 +1428,8 @@ isl::schedule_node InsertStmtExtension(Scop &scop, isl::schedule_node tree, isl: 
read = readList.get_at(i); readExt = readExt.intersect_range(isl::union_set(read.range())); read_extension = isl::map::from(readExt); - tree = InsertStmtExtension(scop, tree, read, read_extension, raw_reads, raw_writes, raw_copyin, schedule, def); + tree = + InsertStmtExtension(scop_info, tree, read, read_extension, raw_reads, raw_writes, raw_copyin, schedule, def); } } return tree; @@ -1464,49 +1450,51 @@ void CheckOutOfBoundAccess(const isl::map &access_elements, const isl::set &orig } } -void PlaceDataCopyBelowImplReadWrite(Scop &scop, isl::schedule_node &tree, const TensorFootprintCluster &cluster, - const isl::map &footprint, const isl::id &tensor_id, - const isl::set &original_elements, const isl::map &exact_writes, - isl::map &read_extension, isl::set &buffered_footprint, const isl::id &cluster_id, - isl::map &extension_map, isl::id &read_id) { +void PlaceDataCopyBelowImplReadWrite(ScopInfo &scop_info, isl::schedule_node &tree, + const TensorFootprintCluster &cluster, const isl::map &footprint, + const isl::id &tensor_id, const isl::set &original_elements, + const isl::map &exact_writes, isl::map &read_extension, + isl::set &buffered_footprint, const isl::id &cluster_id, isl::map &extension_map, + isl::id &read_id) { bool reads = (!cluster.RichReadRelations().is_empty() && cluster.ReadNeedDma()); bool writes = (!cluster.RichWriteRelations().is_empty() && cluster.WriteNeedDma()); if (writes) { - auto tensor_info = scop.GetBufferDefInfo(cluster_id); + auto tensor_info = scop_info.analysis_result_.GetBufferDefInfo(cluster_id); if (MemType::UBL0_ == tensor_info.DstMemType() || MemType::UB_ == tensor_info.DstMemType() || tensor_info.IsPreCubeL1Write()) { - if (!scop.IsInBinds(tensor_id)) writes = false; + if (!scop_info.IsInBinds(tensor_id)) writes = false; } if (tensor_info.IsPreCubeL1Write()) { - if (!scop.IsInBinds(tensor_id)) reads = false; + if (!scop_info.IsInBinds(tensor_id)) reads = false; } } auto fp_space_identity = isl::multi_aff::identity(footprint.get_space().range().map_from_set()); - auto buffer_def = scop.GetBufferDefInfo(cluster_id); + auto buffer_def = scop_info.analysis_result_.GetBufferDefInfo(cluster_id); fp_space_identity = RemoveDimensionOfSizeOne(fp_space_identity, buffer_def.TensorSize(tree.parent())); if (reads) { auto read_mupa = isl::multi_union_pw_aff(fp_space_identity.pullback( isl::multi_aff::wrapped_range_map(footprint.get_space().wrap().set_set_tuple_id(read_id)))); - tree = InsertExtensionBeforeOrAfter(scop, tree.get_child(0), read_extension, read_mupa, isl_bool_true); + tree = InsertExtensionBeforeOrAfter(scop_info, tree.get_child(0), read_extension, read_mupa, isl_bool_true); } if (writes) { isl::schedule_node tree_write = tree.get_child(0); - if (scop.params_.empty() && scop.IsLoad3dL1Ub()) { + if (scop_info.user_config_.GetParams().empty() && scop_info.cube_info_.IsLoad3dL1Ub()) { tree_write = tree; } isl::set writes_set = exact_writes.intersect_range(original_elements).wrap().product(buffered_footprint); - isl::id write_id = isl::id(tree.ctx(), GetIslWriteName(scop, tensor_id)); + isl::id write_id = isl::id(tree.ctx(), scop_info.GetIslWriteName(tensor_id)); isl::map write_extension = extension_map.intersect_range(writes_set).set_tuple_id(isl_dim_out, write_id); auto write_mupa = isl::multi_union_pw_aff(fp_space_identity.pullback( isl::multi_aff::wrapped_range_map(footprint.get_space().wrap().set_set_tuple_id(write_id)))); - tree = InsertExtensionBeforeOrAfter(scop, tree_write, write_extension, write_mupa, isl_bool_false); + tree = 
InsertExtensionBeforeOrAfter(scop_info, tree_write, write_extension, write_mupa, isl_bool_false); } } -void PlaceDataCopyBelowImplFakeReads(Scop &scop, isl::schedule_node &tree, const TensorFootprintCluster &cluster, - isl::map &read_extension, const isl::id &cluster_id) { - auto buffer_def = scop.GetBufferDefInfo(cluster_id); +void PlaceDataCopyBelowImplFakeReads(ScopInfo &scop_info, isl::schedule_node &tree, + const TensorFootprintCluster &cluster, isl::map &read_extension, + const isl::id &cluster_id, const isl::union_map &sched) { + auto buffer_def = scop_info.analysis_result_.GetBufferDefInfo(cluster_id); bool fake_reads = (!cluster.RichReadRelations().is_empty() && cluster.ReadNeedDma() && cluster.ReadNeedExtension()); if (fake_reads) { isl::schedule_node node = tree; @@ -1526,39 +1514,44 @@ void PlaceDataCopyBelowImplFakeReads(Scop &scop, isl::schedule_node &tree, const stmt_extension = stmt_extension.set_tuple_id(isl_dim_out, stmt_tensor_id); isl::union_set readTensor = isl::union_set(stmt_extension.range()); - isl::union_map reads_map = scop.data_.fake_copyin.domain_factor_domain().intersect_range(readTensor.universe()); + isl::union_map reads_map = + scop_info.analysis_result_.GetFakeCopyin().domain_factor_domain().intersect_range(readTensor.universe()); if (!reads_map.is_empty()) { - isl::union_map raw_reads = scop.data_.reads.domain_factor_domain(); - isl::union_map raw_writes = scop.data_.writes.domain_factor_domain(); - isl::union_map raw_copyin = scop.data_.copyin.domain_factor_domain(); + isl::union_map raw_reads = scop_info.analysis_result_.GetReads().domain_factor_domain(); + isl::union_map raw_writes = scop_info.analysis_result_.GetWrites().domain_factor_domain(); + isl::union_map raw_copyin = scop_info.analysis_result_.GetCopyin().domain_factor_domain(); isl::map_list readList = reads_map.get_map_list(); int n = readList.size(); for (int i = 0; i < n; ++i) { - tree = InsertStmtExtension(scop, tree, readList.get_at(i), stmt_extension, raw_reads, raw_writes, raw_copyin, - scop.sch_, buffer_def); + tree = InsertStmtExtension(scop_info, tree, readList.get_at(i), stmt_extension, raw_reads, raw_writes, + raw_copyin, sched, buffer_def); } } } } } -isl::schedule_node PlaceDataCopyBelowImpl(Scop &scop, isl::schedule_node tree, const TensorFootprintCluster &cluster, - const isl::map &footprint, const isl::id &tensor_id, - const isl::set &original_elements, const isl::map &exact_reads, - const isl::map &exact_writes) { +isl::schedule_node PlaceDataCopyBelowImpl(ScopInfo &scop_info, isl::schedule_node tree, + const TensorFootprintCluster &cluster, const isl::map &footprint, + const isl::id &tensor_id, const isl::set &original_elements, + const isl::map &exact_reads, const isl::map &exact_writes, + const isl::union_map &sch) { auto cluster_id = footprint.get_tuple_id(isl_dim_out); - if (!scop.IsConv()) CheckOutOfBoundAccess(exact_reads, original_elements, "read"); + if (!scop_info.cube_info_.IsConv()) CheckOutOfBoundAccess(exact_reads, original_elements, "read"); bool special_dma = false; - if (scop.conv_special_dma_ || (scop.attr_info_.count(ATTR_CONV_SPECIAL_DMA) > 0)) { - if (scop.attr_info_.count(ATTR_CONV_BACKPROP_FILTER) > 0 && scop.attr_info_.count(ATTR_CONV_KERNEL_H) > 0 && - scop.attr_info_.count(ATTR_CONV_KERNEL_W) > 0 && scop.attr_info_.count(ATTR_CONV_FEATURE_C) > 0) { - std::string featureName = scop.ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME) + "_local_L1"; - int kh = scop.ExtractIntFromAttrs(ATTR_CONV_KERNEL_H); - int kw = 
scop.ExtractIntFromAttrs(ATTR_CONV_KERNEL_W); - int ci = scop.ExtractIntFromAttrs(ATTR_CONV_FEATURE_C); + if (scop_info.user_config_.GetConvSpecialDma() || + (scop_info.cube_info_.GetConvAttrInfo().count(ATTR_CONV_SPECIAL_DMA) > 0)) { + if (scop_info.cube_info_.GetConvAttrInfo().count(ATTR_CONV_BACKPROP_FILTER) > 0 && + scop_info.cube_info_.GetConvAttrInfo().count(ATTR_CONV_KERNEL_H) > 0 && + scop_info.cube_info_.GetConvAttrInfo().count(ATTR_CONV_KERNEL_W) > 0 && + scop_info.cube_info_.GetConvAttrInfo().count(ATTR_CONV_FEATURE_C) > 0) { + std::string featureName = scop_info.cube_info_.ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME) + "_local_L1"; + int kh = scop_info.cube_info_.ExtractIntFromAttrs(ATTR_CONV_KERNEL_H); + int kw = scop_info.cube_info_.ExtractIntFromAttrs(ATTR_CONV_KERNEL_W); + int ci = scop_info.cube_info_.ExtractIntFromAttrs(ATTR_CONV_FEATURE_C); if (featureName == cluster_id.get_name() && kh == 7 && kw == 7 && ci == 16) { special_dma = true; } @@ -1576,7 +1569,7 @@ isl::schedule_node PlaceDataCopyBelowImpl(Scop &scop, isl::schedule_node tree, c read_set = read_set.product(buffered_footprint); isl::map extension_map = footprint.wrap().identity().domain_factor_domain().domain_factor_domain(); - isl::id read_id = isl::id(tree.ctx(), GetIslReadName(scop, cluster_id)); + isl::id read_id = isl::id(tree.ctx(), scop_info.GetIslReadName(cluster_id)); isl::map read_extension = extension_map.intersect_range(read_set).set_tuple_id(isl_dim_out, read_id); if (special_dma) { isl::map read_set_map = read_extension.range().unwrap(); @@ -1587,21 +1580,21 @@ isl::schedule_node PlaceDataCopyBelowImpl(Scop &scop, isl::schedule_node tree, c read_extension = read_set_map.wrap().identity().domain_factor_domain().domain_factor_domain().set_tuple_id(isl_dim_out, read_id); } - if (!scop.IsConv()) CheckOutOfBoundAccess(exact_writes, original_elements, "write"); + if (!scop_info.cube_info_.IsConv()) CheckOutOfBoundAccess(exact_writes, original_elements, "write"); - PlaceDataCopyBelowImplReadWrite(scop, tree, cluster, footprint, tensor_id, original_elements, exact_writes, + PlaceDataCopyBelowImplReadWrite(scop_info, tree, cluster, footprint, tensor_id, original_elements, exact_writes, read_extension, buffered_footprint, cluster_id, extension_map, read_id); - PlaceDataCopyBelowImplFakeReads(scop, tree, cluster, read_extension, cluster_id); + PlaceDataCopyBelowImplFakeReads(scop_info, tree, cluster, read_extension, cluster_id, sch); - scop.schedule_ = tree.get_schedule(); return tree; } -isl::schedule_node PlaceInnerDataCopyBelow(Scop &scop, const isl::schedule_node &tree, +isl::schedule_node PlaceInnerDataCopyBelow(ScopInfo &scop_info, const isl::schedule_node &tree, const TensorFootprintCluster &cluster, const TensorFootprintCluster &outer_scope_cluster, const isl::id &tensor_id, - const isl::id &cluster_id, const isl::id &outer_scope_cluster_id) { + const isl::id &cluster_id, const isl::id &outer_scope_cluster_id, + const isl::union_map &sch) { // map :: [S -> O] -> P_inner isl::map inner_scope_footprint = isl::map(cluster.ComputeBufferedFootprints()).set_tuple_id(isl_dim_out, cluster_id); @@ -1634,12 +1627,13 @@ isl::schedule_node PlaceInnerDataCopyBelow(Scop &scop, const isl::schedule_node inner_scope_footprint = inner_scope_footprint.apply_domain(outer_scope_footprint); - return PlaceDataCopyBelowImpl(scop, tree, cluster, inner_scope_footprint, tensor_id, outerScopeGroupFootprint, + return PlaceDataCopyBelowImpl(scop_info, tree, cluster, inner_scope_footprint, tensor_id, outerScopeGroupFootprint, 
cluster.RichReadRelations().wrap().apply(outer_scope_footprint).unwrap(), - cluster.RichWriteRelations().wrap().apply(outer_scope_footprint).unwrap()); + cluster.RichWriteRelations().wrap().apply(outer_scope_footprint).unwrap(), sch); } -isl::schedule_node PlaceIm2colBelow(Scop &scop, const isl::schedule_node &tree, const TensorFootprintCluster &cluster, +isl::schedule_node PlaceIm2colBelow(ScopInfo &scop_info, const isl::schedule_node &tree, + const TensorFootprintCluster &cluster, const TensorFootprintCluster &outer_scope_cluster, const isl::id &cluster_id, const isl::id &outer_scope_cluster_id) { // map :: [S -> O] -> P_inner @@ -1671,24 +1665,25 @@ isl::schedule_node PlaceIm2colBelow(Scop &scop, const isl::schedule_node &tree, // map :: [S -> P_outer] -> P_inner inner_scope_footprint = inner_scope_footprint.apply_domain(outer_scope_footprint); - return PlaceIm2colBelowImpl(scop, tree, cluster, inner_scope_footprint, + return PlaceIm2colBelowImpl(scop_info, tree, cluster, inner_scope_footprint, outer_scope_cluster.BufferedFootprint().set_tuple_id(outer_scope_cluster_id), outer_scope_cluster.BufferedFootprint().set_tuple_id(outer_scope_cluster_id)); } -isl::schedule_node PlaceOuterDataCopyBelow(Scop &scop, const isl::schedule_node &tree, +isl::schedule_node PlaceOuterDataCopyBelow(ScopInfo &scop_info, const isl::schedule_node &tree, const TensorFootprintCluster &cluster, const isl::id &tensor_id, - const isl::id &cluster_id) { + const isl::id &cluster_id, const isl::union_map &sch, + const isl::space &sch_space) { CHECK(!cluster_id.is_null()) << "expected cluster id"; - auto tensor_elements = CollectTensorSet(scop, tensor_id); + auto tensor_elements = CollectTensorSet(scop_info, tensor_id, sch_space); isl::map footprint; if (cluster.foot_print_.box.is_valid()) { footprint = isl::map(cluster.ComputeBufferedFootprints()).set_tuple_id(isl_dim_out, cluster_id); } else { footprint = isl::map(cluster.IdentityBufferFootprint()).set_tuple_id(isl_dim_out, cluster_id); } - return PlaceDataCopyBelowImpl(scop, tree, cluster, footprint, tensor_id, tensor_elements, cluster.RichReadRelations(), - cluster.RichWriteRelations()); + return PlaceDataCopyBelowImpl(scop_info, tree, cluster, footprint, tensor_id, tensor_elements, + cluster.RichReadRelations(), cluster.RichWriteRelations(), sch); } void UniteInterleavedReadsAndWrites(std::vector> &clusters) { diff --git a/src/poly/dma_inject.h b/src/poly/dma_inject.h index eccc3657e241881db40b95a3c05a612aa501e960..0c61618d67d1ee653f4d1d6ac35bde8a1a6e4978 100644 --- a/src/poly/dma_inject.h +++ b/src/poly/dma_inject.h @@ -13,20 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
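With the ScopInfo refactor, the data-copy placement helpers in dma_inject.cc no longer write back scop.schedule_; the schedule is passed in explicitly and the caller keeps the returned subtree. A hedged call-site sketch (the enclosing MemoryManager pass and the fp_cluster/tensor_id variables are assumptions, not part of this hunk):

    // Hypothetical call site in a schedule pass:
    isl::union_map sch = schedule.get_map();                    // flat schedule used for dependence analysis
    isl::space sch_space = schedule.get_domain().get_space();   // replaces the old scop.schedule_ access
    tree = PlaceOuterDataCopyBelow(scop_info_, tree, *fp_cluster, tensor_id, cluster_id, sch, sch_space);
    // The caller is now responsible for rebuilding the schedule from `tree`;
    // the old helpers updated scop.schedule_ = tree.get_schedule() internally.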
*/ + #ifndef POLY_DMA_INJECT_H_ #define POLY_DMA_INJECT_H_ -#pragma once #include -#include -#include -#include -#include -#include -#include #include #include "poly/isl.h" -#include "poly/scop.h" +#include "poly/scop_info.h" namespace akg { namespace ir { @@ -177,30 +171,36 @@ std::vector ExpandInvalidDims(const std::vector &invalid_dims, const i int &first_invalid_domain_dim); isl::multi_aff ComputeBufferFootprint(const isl::map &access, const ScopedFootprint &foot_print); -isl::schedule_node PlaceDataCopyBelowImpl(Scop &scop, isl::schedule_node tree, const TensorFootprintCluster &cluster, - const isl::map &buffer_footprint, const isl::id &tensor_id, - const isl::set &original_elements, const isl::map &exact_reads, - const isl::map &exact_writes); +isl::schedule_node PlaceDataCopyBelowImpl(ScopInfo &scop_info, isl::schedule_node tree, + const TensorFootprintCluster &cluster, const isl::map &buffer_footprint, + const isl::id &tensor_id, const isl::set &original_elements, + const isl::map &exact_reads, const isl::map &exact_writes, + const isl::union_map &sch); -void PlaceDataCopyBelowImplReadWrite(Scop &scop, isl::schedule_node &tree, const TensorFootprintCluster &cluster, - const isl::map &footprint, const isl::id &tensor_id, - const isl::set &original_elements, const isl::map &exact_writes, - isl::map &read_extension, isl::set &buffered_footprint, const isl::id &cluster_id, - isl::map &extension_map, isl::id &read_id); +void PlaceDataCopyBelowImplReadWrite(ScopInfo &scop_info, isl::schedule_node &tree, + const TensorFootprintCluster &cluster, const isl::map &footprint, + const isl::id &tensor_id, const isl::set &original_elements, + const isl::map &exact_writes, isl::map &read_extension, + isl::set &buffered_footprint, const isl::id &cluster_id, isl::map &extension_map, + isl::id &read_id); -void PlaceDataCopyBelowImplFakeReads(Scop &scop, isl::schedule_node &tree, const TensorFootprintCluster &cluster, - isl::map &read_extension, const isl::id &cluster_id); +void PlaceDataCopyBelowImplFakeReads(ScopInfo &scop_info, isl::schedule_node &tree, + const TensorFootprintCluster &cluster, isl::map &read_extension, + const isl::id &cluster_id, const isl::union_map &sch); -isl::schedule_node PlaceInnerDataCopyBelow(Scop &scop, const isl::schedule_node &tree, +isl::schedule_node PlaceInnerDataCopyBelow(ScopInfo &scop_info, const isl::schedule_node &tree, const TensorFootprintCluster &cluster, const TensorFootprintCluster &outer_scope_cluster, const isl::id &tensor_id, - const isl::id &cluster_id, const isl::id &outer_scope_cluster_id); + const isl::id &cluster_id, const isl::id &outer_scope_cluster_id, + const isl::union_map &sch); -isl::schedule_node PlaceOuterDataCopyBelow(Scop &scop, const isl::schedule_node &tree, +isl::schedule_node PlaceOuterDataCopyBelow(ScopInfo &scop_info, const isl::schedule_node &tree, const TensorFootprintCluster &cluster, const isl::id &tensor_id, - const isl::id &cluster_id); + const isl::id &cluster_id, const isl::union_map &sch, + const isl::space &sch_space); -isl::schedule_node PlaceIm2colBelow(Scop &scop, const isl::schedule_node &tree, const TensorFootprintCluster &cluster, +isl::schedule_node PlaceIm2colBelow(ScopInfo &scop_info, const isl::schedule_node &tree, + const TensorFootprintCluster &cluster, const TensorFootprintCluster &outer_scope_cluster, const isl::id &cluster_id, const isl::id &outer_scope_cluster_id); @@ -210,7 +210,7 @@ class AffineBase { public: virtual ~AffineBase() = default; virtual isl::map ConstructAffine(isl::map original) = 0; - 
virtual bool NotNeedConstruct(std::string name, Scop &scop) = 0; + virtual bool NotNeedConstruct(std::string name, ScopInfo &scop_info) = 0; }; class GemmInnerTransposeAffine : public AffineBase { @@ -221,13 +221,13 @@ class GemmInnerTransposeAffine : public AffineBase { isl::map ConstructAffine(isl::map original_map) final; void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; } - bool NotNeedConstruct(std::string name, Scop &scop) override { + bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override { // right matrix filter !B tensor - if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop.IsB(name)) { + if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop_info.cube_info_.IsB(name)) { return true; } // left matrix filter !A tensor - if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) { + if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) { return true; } return false; @@ -246,13 +246,13 @@ class GemmTransposeAffine : public AffineBase { void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; } - bool NotNeedConstruct(std::string name, Scop &scop) override { + bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override { // right matrix filter !B tensor - if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop.IsB(name)) { + if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop_info.cube_info_.IsB(name)) { return true; } // left matrix filter !A tensor - if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) { + if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) { return true; } return false; @@ -271,17 +271,17 @@ class GemmTransposeBlockAffine : public AffineBase { void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; } - bool NotNeedConstruct(std::string name, Scop &scop) override { + bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override { // right matrix filter !B tensor - if (AffineTensor::RIGHT_TENSOR == is_right_matrix_ && !scop.IsB(name)) { + if (AffineTensor::RIGHT_TENSOR == is_right_matrix_ && !scop_info.cube_info_.IsB(name)) { return true; } // left matrix filter !A tensor - if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) { + if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) { return true; } - if (AffineTensor::OUT_TENSOR == is_right_matrix_ && !scop.IsC(name)) { + if (AffineTensor::OUT_TENSOR == is_right_matrix_ && !scop_info.cube_info_.IsC(name)) { return true; } @@ -302,8 +302,8 @@ class Im2colAffine : public AffineBase { void ConstructAffineMap(isl::map &footprint, std::vector &v_aff_x, std::vector &v_aff_y, const isl::map &original_map, isl::local_space &ls); - bool NotNeedConstruct(std::string name, Scop &scop) override { - if (!scop.IsA(name)) { + bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override { + if (!scop_info.cube_info_.IsA(name)) { return true; } return false; @@ -319,8 +319,8 @@ class WeightAffine : public AffineBase { isl::map ConstructAffine(isl::map original_map) final; - bool NotNeedConstruct(std::string name, Scop &scop) override { - if (!scop.IsB(name)) { + bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override { + if (!scop_info.cube_info_.IsB(name)) { return true; } return false; @@ -339,8 +339,8 @@ class FractalAffine : public AffineBase { void ConstructAffineMap(isl::map &footprint, std::vector &v_aff_x, std::vector &v_aff_y, const isl::map &original_map, isl::local_space &ls); - bool 
NotNeedConstruct(std::string name, Scop &scop) override { - if (!scop.IsA(name)) { + bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override { + if (!scop_info.cube_info_.IsA(name)) { return true; } return false; @@ -371,7 +371,7 @@ class AffineRefGroupConstructor { void create(); - std::unique_ptr ConstructRefGroup(Scop &scop, const isl::union_map &accesses, + std::unique_ptr ConstructRefGroup(ScopInfo &scop_info, const isl::union_map &accesses, const isl::union_set &domain, const isl::union_map &schedule, ReferenceType type); @@ -391,7 +391,7 @@ class AffineRefGroupConstructor { AffineType type_ = AffineType::AFFINE_GEMM; }; -std::unique_ptr ConstructAffineFpCluster(Scop &scop, const isl::union_map &accesses, +std::unique_ptr ConstructAffineFpCluster(ScopInfo &info, const isl::union_map &accesses, const isl::union_set &domain, const isl::union_map &schedule, ReferenceType type, AffineType affine_type, diff --git a/src/poly/dump_log.cc b/src/poly/dump_log.cc index 3de299df1fd1303deda3c5ea985a60d4a93bef56..864d0a62f2cf4fe6d16c5633d972e87d953f741f 100644 --- a/src/poly/dump_log.cc +++ b/src/poly/dump_log.cc @@ -20,9 +20,10 @@ #include #include #include +#include +#include #include "poly/poly_util.h" -#include "poly/scop.h" #include "poly/dma_inject.h" namespace akg { @@ -152,6 +153,11 @@ void PrettyPrintSchTree(std::FILE *fp, const isl::schedule &sch) { } } +std::string PrettyPrintSchTree(const isl::schedule &sch) { + std::string sch_tree_str = DumpSchTreeToString(sch); + return FormatSchTreeStr(sch_tree_str); +} + /* * Check that file name is a simple relative path (does not start with "/", and does not include "." or ".."). * FileName should not include extension, and the extension will be appended to FileName. @@ -218,6 +224,7 @@ bool CompareSchTreeWithString(const std::string &compare_sch_, const isl::schedu void PrintHeader(std::ofstream &of, const std::string &str) { of << std::endl << ">>>>>>>>>> " << str << " <<<<<<<<<<" << std::endl; } +void PrintHeader(const std::string &str) { std::cout << ">>>>>>>>>> " << str << " <<<<<<<<<<" << std::endl; } void DumpNode(std::ofstream &of, const air::Node *node) { if (node->IsInstance()) { @@ -274,28 +281,28 @@ void CreateDirIfNotExist(const std::string &file_name) { free(file_name_); } -void Scop::DumpScopDataBasics(std::ofstream &of) { +void AnalysisResult::DumpScopDataBasics(std::ofstream &of) { PrintHeader(of, "statements"); - for (const auto &stmt : data_.statements) { + for (const auto &stmt : GetStatementMap()) { of << stmt.first << " : "; DumpNode(of, stmt.second); of << std::endl; } PrintHeader(of, "accesses"); - for (const auto &stmt : data_.accesses) { + for (const auto &stmt : GetAccessMap()) { of << stmt.second << " : "; DumpNode(of, stmt.first); of << std::endl; } PrintHeader(of, "domains"); - for (const auto &stmt : data_.domains) { + for (const auto &stmt : GetOperatorDomainMap()) { of << stmt.first << " : param_space " << stmt.second.param_space << std::endl; } PrintHeader(of, "stmt_op_Info"); - for (const auto &stmt : data_.stmt_op_Info) { + for (const auto &stmt : GetStmtOpInfoMap()) { of << stmt.first << " : ops [ "; for (auto op : stmt.second.ops) { of << int(op) << ", "; @@ -307,92 +314,79 @@ void Scop::DumpScopDataBasics(std::ofstream &of) { of << "]" << std::endl; } - PrintHeader(of, "iterators"); - for (const auto &it : data_.iterators) { - of << it.first << " : [ "; - for (const auto &str : it.second) { - of << str << ", "; - } - of << "]" << std::endl; - } - PrintHeader(of, "reads"); - of << 
FormatMupaStr(data_.reads) << std::endl; + of << FormatMupaStr(GetReads()) << std::endl; PrintHeader(of, "writes"); - of << FormatMupaStr(data_.writes) << std::endl; + of << FormatMupaStr(GetWrites()) << std::endl; PrintHeader(of, "copyin"); - of << FormatMupaStr(data_.copyin) << std::endl; + of << FormatMupaStr(GetCopyin()) << std::endl; PrintHeader(of, "fake_copyin"); - of << FormatMupaStr(data_.fake_copyin) << std::endl; + of << FormatMupaStr(GetFakeCopyin()) << std::endl; PrintHeader(of, "inter_band_dependency"); - of << FormatMupaStr(data_.inter_band_dependency) << std::endl; + of << FormatMupaStr(GetInnerBandDependency()) << std::endl; PrintHeader(of, "transfer_stmt"); - of << FormatMupaStr(data_.transfer_stmt) << std::endl; + of << FormatMupaStr(GetTransferStmt()) << std::endl; PrintHeader(of, "reduce_stmts"); - for (const auto &stmt : data_.reduce_stmts) { + for (const auto &stmt : GetReduceStmtMap()) { of << stmt.first << ": reduce axis [ "; for (const auto &axis : stmt.second) { of << axis << " "; } of << "]" << std::endl; } - - PrintHeader(of, "group_filter_map"); - for (const auto &group : group_filter_map_) { - of << group.first << " : [ "; - for (auto filter : group.second) { - of << filter << ", "; - } - of << "]" << std::endl; - } } -void Scop::DumpScopDataAdvanced(std::ofstream &of) { +void ScopInfo::DumpScopDataAdvanced(std::ofstream &of) { PrintHeader(of, "binds"); - for (auto bind : binds_) { + auto binds = user_config_.GetBind(); + for (auto bind : binds) { of << bind.first << " : " << bind.second << std::endl; } PrintHeader(of, "binds_orig"); - for (auto bind : binds_orig_) { + auto binds_orig = user_config_.GetOriginBind(); + for (auto bind : binds_orig) { of << bind.first << " : " << bind.second << std::endl; } PrintHeader(of, "realize_from_input"); - for (const auto &id : realize_from_input_) { + auto realize_from_input = user_config_.GetRealizeFromInput(); + for (const auto &id : realize_from_input) { of << id << ", "; } of << std::endl; PrintHeader(of, "dim_infos"); - for (const auto &dim_info : dim_infos_) { + for (const auto &dim_info : analysis_result_.GetTileSizes()) { of << "index=" << dim_info.index << " axis=" << dim_info.axis << " l1_tiling_size=" << dim_info.l1_tiling_size << " l0_tiling_size=" << dim_info.l0_tiling_size << " dim_seq=" << dim_info.dim_seq << std::endl; } PrintHeader(of, "fractal_int_info"); - for (const auto &info : fractal_int_info_) { + for (const auto &info : cube_info_.fractal_int_info_) { of << info.first << " : " << info.second << std::endl; } PrintHeader(of, "fractal_str_info"); - for (const auto &info : fractal_str_info_) { + for (const auto &info : cube_info_.fractal_str_info_) { of << info.first << " : " << info.second << std::endl; } PrintHeader(of, "conditional_write_buffer_footprints"); - for (const auto &tensor : conditional_write_buffer_footprints_) { + auto conditional_write_buffer_footprints = analysis_result_.GetConditionalWriteBufferFootprints(); + for (const auto &tensor : conditional_write_buffer_footprints) { of << tensor << std::endl; } PrintHeader(of, "tensor_name_flows"); - for (const auto &name_flow : tensor_name_flows_) { + auto tensor_name_flows = analysis_result_.GetTensorNameFlows(); + for (const auto &name_flow : tensor_name_flows) { of << name_flow.first << " : [ "; for (const auto &name : name_flow.second) { of << name << ", "; @@ -401,7 +395,8 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) { } PrintHeader(of, "tensor_memflows"); - for (const auto &mem_flow : tensor_mem_flows_) { + auto 
tensor_mem_flows = analysis_result_.GetTensorMemFlows(); + for (const auto &mem_flow : tensor_mem_flows) { of << mem_flow.first << " : [ "; for (auto mem : mem_flow.second) { of << static_cast(mem) << ", "; @@ -409,25 +404,8 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) { of << "]" << std::endl; } - PrintHeader(of, "n_clusters"); - for (const auto &cluster : n_clusters_) { - of << cluster.first << " : " << cluster.second << std::endl; - } - - PrintHeader(of, "bufferedDecls"); - for (const auto &buffered_decl : buffered_decls_) { - of << buffered_decl.first << " : " - << "tensor_id=" << buffered_decl.second.tensor_id << "type=" << buffered_decl.second.type - << "kind=" << static_cast(buffered_decl.second.kind) << "tensor=" << buffered_decl.second.tensor - << "size=["; - for (auto size : buffered_decl.second.sizes) { - of << size << ","; - } - of << "]" << std::endl; - } - PrintHeader(of, "active_buffer_footprints"); - for (const auto &active_buffer_footprint : active_buffer_footprints_) { + for (const auto &active_buffer_footprint : analysis_result_.active_buffer_footprints_) { of << "cluster_id : " << active_buffer_footprint.second.cluster_id << std::endl << "domain : " << FormatMupaStr(active_buffer_footprint.first) << std::endl << "cluster : " << *(active_buffer_footprint.second.cluster) << std::endl @@ -436,81 +414,82 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) { } PrintHeader(of, "buffered_decl_infos"); - DumpBufferDefInfos(of); - of << std::endl; - - of << "custom_tiling : "; - if (custom_tiling_.empty()) of << "empty" << std::endl; - for (const auto &tiling : custom_tiling_) { - of << tiling << " "; - } + analysis_result_.DumpBufferDefInfos(of); of << std::endl; PrintHeader(of, "attr_info"); - for (const auto &info : attr_info_) { + for (const auto &info : cube_info_.GetConvAttrInfo()) { of << info.first << " : " << info.second << std::endl; } } -void Scop::DumpScopDataScheduleAttrs(std::ofstream &of) { +void UserConfig::DumpScopDataScheduleAttrs(std::ofstream &of) { PrintHeader(of, "schedule attrs"); - of << "dim : " << b_dim_ << std::endl; - of << "kernel_h : " << matB_dim_h_ << std::endl; - of << "kernel_w : " << matB_dim_w_ << std::endl; - of << "conv_backprop_filter : " << conv_back_prop_filter_ << std::endl; - of << "bypassL1 : " << bypassL1_ << std::endl; - of << "dump_tuning_level : " << dump_tuning_level_ << std::endl; - of << "pragma_rmselfdep : " << remove_self_dependence_ << std::endl; - of << "pragma_force_rmselfdep : " << force_remove_self_dependence_ << std::endl; - of << "pragma_reschedule : " << compute_reschedule_ << std::endl; - of << "pragma_disable_schedule_shift : " << disable_schedule_shift_ << std::endl; - of << "pragma_enable_schedule_max_constant : " << enable_schedule_max_constant_ << std::endl; - of << "pragma_disable_loop_reversal : " << disable_loop_reversal_ << std::endl; - of << "pragma_disable_loop_fusion : " << disable_loop_fusion_ << std::endl; - of << "pragma_modshift : " << mod_schedule_shift_ << std::endl; - of << "pragma_conv_special_dma : " << conv_special_dma_ << std::endl; - of << "pragma_reorder_schedule : " << reorder_schedule_ << std::endl; - of << "pragma_checkcoincident : " << tile_check_coincident_ << std::endl; - of << "pragma_opt_for_davinci : " << optimize_for_davinci_ << std::endl; - of << "pragma_sink_last_axis : " << sink_last_axis_ << std::endl; - of << "pragma_keep_outer_band_order : " << keep_outer_band_order_ << std::endl; - of << "pragma_disable_group : " << disable_group_ << std::endl; - of << 
"pragma_tile_inner_band : " << tile_inner_band_ << std::endl; - of << "kernel_name : " << kernel_name_ << std::endl; - of << "dump_poly_dir : " << dump_poly_dir_ << std::endl; - of << "isolated_idx : " << isolated_idx_ << std::endl; - of << "dynamic_shape_bound : " << dynamic_shape_bound_ << std::endl; - of << "pragma_tilesize_is_var : " << tile_size_is_var_ << std::endl; - of << "pragma_outerband_need_split : " << outer_band_need_split_ << std::endl; - of << "pragma_is_conv : " << pragma_is_conv_ << std::endl; + of << "dump_poly_dir : " << GetDumpPolyDir() << std::endl; + + of << "dump_tuning_level : " << GetDumpTuningLevel() << std::endl; + of << "dim : " << GetBDim() << std::endl; + + of << "pragma_rmselfdep : " << GetRemoveSelfDependence() << std::endl; + of << "pragma_force_rmselfdep : " << GetForceRemoveSelfDependence() << std::endl; + of << "pragma_reschedule : " << GetComputeReschedule() << std::endl; + of << "pragma_disable_schedule_shift : " << GetDisableScheduleShift() << std::endl; + of << "pragma_enable_schedule_max_constant : " << GetEnableScheduleMaxConstant() << std::endl; + of << "pragma_disable_loop_reversal : " << GetDisableLoopReversal() << std::endl; + of << "pragma_disable_loop_fusion : " << GetDisableLoopFusion() << std::endl; + of << "pragma_modshift : " << GetModScheduleShift() << std::endl; + of << "pragma_reorder_schedule : " << GetReorderSchedule() << std::endl; + of << "pragma_checkcoincident : " << GetTileCheckCoincident() << std::endl; + of << "pragma_opt_for_davinci : " << GetOptimizeForDavinci() << std::endl; + of << "pragma_sink_last_axis : " << GetSinkLastAxis() << std::endl; + of << "pragma_keep_outer_band_order : " << GetKeepOuterBandOrder() << std::endl; + of << "pragma_disable_group : " << GetDisableGroup() << std::endl; + of << "pragma_tile_inner_band : " << GetTileInnerBand() << std::endl; + of << "isolated_idx : " << GetIsolatedIdx() << std::endl; + of << "pragma_outerband_need_split : " << GetOuterBandNeedSplit() << std::endl; + + of << "dynamic_shape_bound : " << GetDynamicShapeBound() << std::endl; + of << "pragma_tilesize_is_var : " << GetTileSizeIsVar() << std::endl; + + of << "kernel_name : " << GetKernelName() << std::endl; + of << "kernel_h : " << GetMatBDimH() << std::endl; + of << "kernel_w : " << GetMatBDimW() << std::endl; + of << "conv_backprop_filter : " << GetConvBackPropFilter() << std::endl; + of << "bypassL1 : " << GetByPassL1() << std::endl; + of << "pragma_is_conv : " << GetPragmaIsConv() << std::endl; + of << "pragma_conv_special_dma : " << GetConvSpecialDma() << std::endl; } -bool Scop::DumpScopData(const std::string &file_name) { +bool ScopInfo::DumpScopData(const std::string &file_name) { std::string canonical_log_name = FilePathCanonicalize(file_name, true); if (!CreateFileIfNotExist(canonical_log_name)) return false; std::ofstream of; of.open(canonical_log_name, std::ios::out); if (!of.is_open()) return false; - DumpScopDataBasics(of); + analysis_result_.DumpScopDataBasics(of); DumpScopDataAdvanced(of); - DumpScopDataScheduleAttrs(of); + user_config_.DumpScopDataScheduleAttrs(of); of.close(); return true; } -void Scop::DumpSchTree(const std::string &file_name, const isl::schedule &sch_dump) { - if (dump_pass_ir_) { +void ScopInfo::DumpSchTree(const std::string &file_name, const isl::schedule &sch_dump) { + std::stringstream final_file_name; + final_file_name << std::setw(2) << std::setfill('0') << dump_schtree_count << "_" << file_name + << std::string(cube_info_.IsSpecGemm() ? 
"_specgemm" : ""); + if (user_config_.GetDumpPassIr()) { #if DUMP_IR - DumpSchTreeImpl(CreateDumpDir(file_name), sch_dump); + DumpSchTreeImpl(CreateDumpDir(final_file_name.str()), sch_dump); + dump_schtree_count++; #endif #if DUMP_SCOP_DATA #if DUMP_SCOP_DATA_PER_PASS - static_cast(DumpScopData(CreateDumpDir(file_name))); + static_cast(DumpScopData(CreateDumpDir(final_file_name.str()))); #else static_cast(DumpScopData(CreateDumpDir("scop"))); #endif @@ -518,29 +497,29 @@ void Scop::DumpSchTree(const std::string &file_name, const isl::schedule &sch_du } } -std::string Scop::AddDumpDir(const std::string &file_name) { +std::string ScopInfo::AddDumpDir(const std::string &file_name) { std::string real_file_name = file_name; - bool is_specgemm = (isolated_idx_ > 0); + bool is_specgemm = (user_config_.GetIsolatedIdx() > 0); if (is_specgemm) { - std::string dump_isolate_dir = "specgemm_" + std::to_string(isolated_idx_); + std::string dump_isolate_dir = "specgemm_" + std::to_string(user_config_.GetIsolatedIdx()); real_file_name = dump_isolate_dir + '/' + real_file_name; } #if (!DUMP_IN_CURRENT_DIR) - if (!dump_poly_dir_.empty()) { - real_file_name = dump_poly_dir_ + '/' + real_file_name; + if (!user_config_.GetDumpPolyDir().empty()) { + real_file_name = user_config_.GetDumpPolyDir() + '/' + real_file_name; } #endif return real_file_name; } -std::string Scop::CreateDumpDir(const std::string &file_name) { +std::string ScopInfo::CreateDumpDir(const std::string &file_name) { std::string real_file_name = AddDumpDir(file_name); CreateDirIfNotExist(real_file_name); return real_file_name; } -void Scop::DumpBufferDefInfos(std::ostream &out) { +void AnalysisResult::DumpBufferDefInfos(std::ostream &out) { for (size_t index = 0; index < buffer_def_infos_.size(); index++) { out << "\r\nbufferedDefInfos_[" << index << "]: " << std::endl; out << " tensor_id : " << buffer_def_infos_[index].tensor_id << std::endl; @@ -552,6 +531,48 @@ void Scop::DumpBufferDefInfos(std::ostream &out) { out << " is_bind_tensor : " << buffer_def_infos_[index].is_bind_tensor << std::endl; } } + +void ScopInfo::DumpTransform(const std::string &file_name, PassInfo &pass_info) { + auto real_path = CreateDumpDir(file_name); + std::ofstream of; + of.open(real_path, std::ios::out); + if (!of.is_open()) { + return; + } + + PrintHeader(of, "group_filter_map"); + for (const auto &group : pass_info.group_filter_map_) { + of << group.first << " : [ "; + for (auto filter : group.second) { + of << filter << ", "; + } + of << "]" << std::endl; + } + + PrintHeader(of, "dependences"); + of << FormatMupaStr(pass_info.dependences_.to_str()) << std::endl; + + PrintHeader(of, "constraints"); + isl_printer *p; + char *s = nullptr; + p = isl_printer_to_str(GetCtx().get()); + CHECK(p != nullptr); + p = isl_printer_set_yaml_style(p, ISL_YAML_STYLE_BLOCK); + p = isl_printer_print_schedule_constraints(p, pass_info.constraints_.get()); + s = isl_printer_get_str(p); + if (s) { + of << FormatMupaStr(s); + free(s); + } + static_cast(isl_printer_free(p)); + + PrintHeader(of, "time_records"); + for (auto time_log : time_records_) { + of << time_log << std::endl; + } + + of.close(); +} } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/dump_log.h b/src/poly/dump_log.h index 48d5ca400c4d89240dff838a0d494569cdba97d8..c9d59303d5f538b8c22b375103bb3ba30bad3f05 100644 --- a/src/poly/dump_log.h +++ b/src/poly/dump_log.h @@ -19,6 +19,7 @@ #include #include #include +#include "poly/poly_util.h" namespace akg { namespace ir { namespace poly { @@ 
-35,11 +36,11 @@ bool CreateFileIfNotExist(const std::string &file_name); void CreateDirIfNotExist(const std::string &file_name); std::string DumpSchTreeToString(const isl::schedule &sch); void DumpSchTreeImpl(const std::string &file_name, const isl::schedule &sch); +std::string PrettyPrintSchTree(const isl::schedule &sch); void PrintHeader(std::ofstream &of, const std::string &str); +void PrintHeader(const std::string &str); void DumpNode(std::ofstream &of, const air::Node *node); -bool CompareSchTreeWithString(const std::string &compare_sch, const isl::schedule &sch); - } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/isl_emitter.cc b/src/poly/isl_emitter.cc index b273addff94789a8b610be1c1068cff8f7a3387a..6fb6c7562906540afaa2968dc88cac0d7c2952aa 100644 --- a/src/poly/isl_emitter.cc +++ b/src/poly/isl_emitter.cc @@ -203,11 +203,13 @@ Stmt IslEmitter::EmitFor(const isl::ast_node_for &node) { Stmt IslEmitter::EmitIf(const isl::ast_node_if &node) { Expr cond_expr = Interpret(node.get_cond()); + cur_if_list_.push_back(cond_expr.get()); Stmt then_case = EmitAst(node.get_then_node()); Stmt else_case; if (node.has_else_node()) { else_case = EmitAst(node.get_else_node()); } + cur_if_list_.pop_back(); return IfThenElse::make(cond_expr, then_case, else_case); } @@ -230,25 +232,8 @@ Stmt IslEmitter::EmitBlock(const isl::ast_node_block &node) { } } -class ReplaceLoopVar : public air::ir::IRMutator { - public: - explicit ReplaceLoopVar(VarMap v_) : var_map(std::move(v_)) {} - ~ReplaceLoopVar() override = default; - Expr Mutate_(const Variable *op, const Expr &e) final { - for (auto &i : var_map) { - if (op->name_hint == i.first.get_name()) { - return i.second; - } - } - return e; - } - - private: - VarMap var_map; -}; - isl::space IslEmitter::GetDomainSpace(const isl::id &node_id) { - auto dom = isl::union_set(scop_.Domain()); + auto dom = isl::union_set(info_.analysis_result_.Domain()); auto space = isl::space(); dom.foreach_set([&node_id, &space](const isl::set &s) -> void { if (s.get_tuple_id() == node_id) { @@ -265,12 +250,12 @@ isl::space IslEmitter::GetSpace(const isl::id &tensor_id, const Array &ten return space; } -isl::multi_aff IslEmitter::TensorAccessMultAff(const isl::id &tensor_id, const Array &tensor_index, +isl::multi_aff IslEmitter::TensorAccessMultAff(isl::id &tensor_id, const Array &tensor_index, const isl::id &node_id) { CHECK_NE(tensor_index.size(), 0u); isl::pw_multi_aff iter_map = node_info_map_.at(node_id).iterator_map; isl::id stmt_id = iter_map.get_tuple_id(isl_dim_out); - OperatorDomainSpace domain_space = scop_.data_.domains.at(stmt_id); + OperatorDomainSpace domain_space = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id); isl::multi_aff ma = isl::multi_aff::zero(GetSpace(tensor_id, tensor_index, stmt_id)); for (size_t i = 0; i < tensor_index.size(); ++i) { auto aff = Expr2Aff(domain_space.param_space, tensor_index[i]).unbind_params_insert_domain(domain_space.tuple); @@ -335,8 +320,9 @@ class EmitExpr : public air::ir::IRMutator { Map cache_; }; -void FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info, - std::vector active_buf_footprints, isl::id fp_id) { +BufferedFootPrintInfo FindBufferFootprintById(const std::vector &active_buf_footprints, + const isl::id &fp_id) { + BufferedFootPrintInfo buffer_footprint_info; for (const auto &act_buf_fp : active_buf_footprints) { if (act_buf_fp.cluster != nullptr) { for (const auto &fp : act_buf_fp.cluster->tensor_foot_prints) { @@ -347,14 +333,16 @@ void 
FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info, } } } + return buffer_footprint_info; } -bool IsTransferStmt(Scop &scop, isl::id &stmt_id) { - if (!scop.is_spec_gemm_ && scop.is_tiled_) { - isl::union_set transfer_stmt = scop.data_.transfer_stmt; +bool IslEmitter::IsTransferStmt() { + if (info_.analysis_result_.GetIsTiled()) { + isl::union_set transfer_stmt = info_.analysis_result_.GetTransferStmt(); if (!transfer_stmt.is_empty()) { bool name_match = false; - transfer_stmt.foreach_set([&name_match, stmt_id](const isl::set &s) -> void { + auto stmt_id = stmt_id_; + transfer_stmt.foreach_set([&name_match, &stmt_id](const isl::set &s) -> void { if (s.get_tuple_name() == stmt_id.get_name()) { name_match = true; } @@ -365,8 +353,8 @@ bool IsTransferStmt(Scop &scop, isl::id &stmt_id) { return false; } -Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, - Scop::BufferedFootPrintInfo &buffer_footprint_info) { +Stmt IslEmitter::EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, + BufferedFootPrintInfo &buffer_footprint_info) { const auto provide = static_cast(node); Expr value = ReplaceLoopVar(var_map_tmp).Mutate(provide->value); Array args; @@ -380,8 +368,8 @@ Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, return Stmt(); } -Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_footprint_info, - bool &is_transfer_stmt, Scop &scop) { +Stmt IslEmitter::EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, + BufferedFootPrintInfo &buffer_footprint_info) { const Call *call = static_cast(node); Array args; for (auto iv : call->args) { @@ -389,46 +377,35 @@ Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::Buffe } // Not hoisted, emitting just the mapped subscript. if (!buffer_footprint_info.cluster_id) { - std::string call_name = call->name; - if (is_transfer_stmt && (std::string::npos == call_name.find("_local_UB"))) { - call_name = call_name + "_local_UB"; - Tensor t = scop.FindTensor(call_name); - if (t.defined()) { - return Evaluate::make(Call::make(call->type, call_name, args, call->call_type, t->op, call->value_index)); - } else { - LOG(WARNING) << "Call can not found tensor!!! 
tensor name: " << call_name; - } - } return Evaluate::make(Call::make(call->type, call->name, args, call->call_type, call->func, call->value_index)); } return Stmt(); } -bool IsCopyinFromAnotherBand(Scop &scop, isl::multi_aff &access) { - if (!scop.is_spec_gemm_) { - for (isl::map inter_band_dependency : scop.data_.inter_band_dependency.get_map_list()) { - if (inter_band_dependency.get_tuple_id(isl_dim_out) == access.get_tuple_id(isl_dim_out)) { - return true; - } +bool IslEmitter::IsCopyinFromAnotherBand(isl::multi_aff &access) { + for (isl::map inter_band_dependency : info_.analysis_result_.GetInnerBandDependency().get_map_list()) { + if (inter_band_dependency.get_tuple_id(isl_dim_out) == access.get_tuple_id(isl_dim_out)) { + return true; } } return false; } -void AffSubForAstToSchedule(isl::pw_multi_aff &ast_to_schedule, bool &is_transfer_stmt, - bool &is_copyin_from_another_band) { +isl::pw_multi_aff &AffSubForAstToSchedule(isl::pw_multi_aff &ast_to_schedule, bool is_transfer_stmt, + bool is_copyin_from_another_band) { if (is_transfer_stmt || is_copyin_from_another_band) { isl_pw_multi_aff *pma1 = ast_to_schedule.copy(); isl_pw_multi_aff *pma2 = ast_to_schedule.copy(); isl_pw_multi_aff *pma = isl_pw_multi_aff_sub(pma1, pma2); ast_to_schedule = isl::manage(pma); } + return ast_to_schedule; } -Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, const Node *node, Array &args) { +Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array &args) { const auto provide = static_cast(node); - Tensor t = scop.FindTensor(var); - if (scop.CountBufferDefInfo(var)) { + Tensor t = info_.FindTensor(var); + if (info_.analysis_result_.CountBufferDefInfo(var)) { realize_may_def_.insert(var); if_map_[var] = cur_if_list_; if (cur_if_list_.empty()) { @@ -439,10 +416,10 @@ Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, co return s; } -Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const Node *node, Array &args) { +Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array &args) { const Call *call = static_cast(node); - Tensor t = scop.FindTensor(var); - if (scop.CountBufferDefInfo(var)) { + Tensor t = info_.FindTensor(var); + if (info_.analysis_result_.CountBufferDefInfo(var)) { realize_use_.insert(var); if (!if_map_.count(var) || !AOutThanB(if_map_.at(var), cur_if_list_)) { realize_use_with_may_def_.insert(var); @@ -451,25 +428,6 @@ Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const return Evaluate::make(Call::make(call->type, var.get_name(), args, call->call_type, t->op, t->value_index)); } -void GetNameWithoutLocal(isl::id &tensor_id, Scop &scop) { - if (!scop.is_spec_gemm_) { - size_t pos = tensor_id.get_name().find("_local_"); - std::string substr = tensor_id.get_name().substr(0, pos); - if (pos != 0) tensor_id = isl::id(tensor_id.ctx(), substr); - } -} - -Stmt EmitAccessNodeImpl(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_footprint_info, - bool &is_transfer_stmt, Scop &scop, bool is_Provide) { - Stmt s; - if (is_Provide) { - s = EmitAccessNodeProvide(node, var_map_tmp, buffer_footprint_info); - } else { - s = EmitAccessNodeCall(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop); - } - return s; -} - Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const Array &tensor_index, const VarMap &var_map_tmp) { // Scalars are not hoisted or remapped. 
@@ -481,40 +439,34 @@ Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const auto build = node_info_map_.at(node_id_).build; auto iterator_map = node_info_map_.at(node_id_).iterator_map; - CHECK_EQ(scop_.data_.accesses.count(node), 1u) + CHECK_EQ(info_.analysis_result_.GetAccessMap().count(node), 1u) << "generating tensor " << name << " not in Scop" << node << " not allowed "; - auto fp_id = scop_.data_.accesses.at(node); + auto fp_id = info_.analysis_result_.GetAccessMap().at(node); - Scop::BufferedFootPrintInfo buffer_footprint_info; - std::vector active_buf_footprint; - for (const auto &kv : scop_.ActiveBufferFootprints()) { + std::vector active_buf_footprint; + for (const auto &kv : info_.analysis_result_.ActiveBufferFootprints()) { if (kv.first.intersect(isl::union_set(Domain())).is_empty()) { continue; } active_buf_footprint.emplace_back(kv.second); } - FindBufferFootprintById(buffer_footprint_info, active_buf_footprint, fp_id); - - bool is_transfer_stmt = false; - is_transfer_stmt = IsTransferStmt(scop_, stmt_id_); + BufferedFootPrintInfo buffer_footprint_info = FindBufferFootprintById(active_buf_footprint, fp_id); if (node->IsInstance()) { - if (EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, true).defined()) - return EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, true); + auto stmt = EmitAccessNodeProvide(node, var_map_tmp, buffer_footprint_info); + if (stmt.defined()) return stmt; } if (node->IsInstance()) { - if (EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, false).defined()) - return EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, false); + auto stmt = EmitAccessNodeCall(node, var_map_tmp, buffer_footprint_info); + if (stmt.defined()) return stmt; } - auto buf_def = scop_.GetBufferDefInfo(buffer_footprint_info.cluster_id); - GetNameWithoutLocal(buf_def.tensor_id, scop_); + auto buf_def = info_.analysis_result_.GetBufferDefInfo(buffer_footprint_info.cluster_id); auto access = TensorAccessMultAff(buf_def.tensor_id, tensor_index, node_id_); - bool is_copyin_from_another_band = false; - is_copyin_from_another_band = IsCopyinFromAnotherBand(scop_, access); + bool is_copyin_from_another_band = IsCopyinFromAnotherBand(access); auto memory_hoist = buffer_footprint_info.cluster->ComputeBufferedFootprints(); if (is_copyin_from_another_band) { @@ -523,33 +475,31 @@ Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const // split read-only or write-only input tensor memory_hoists // we need to find tensor by name because tensor_id is a fake isl::id - bool is_input_tensor = scop_.FindTensorInOrig(buf_def.tensor_id.name()).defined(); + bool is_input_tensor = info_.FindTensorInOrig(buf_def.tensor_id.name()).defined(); if (is_input_tensor && buffer_footprint_info.cluster->foot_print_.should_split) { memory_hoist = buffer_footprint_info.cluster->UnshiftedBufferFootprint(memory_hoist, fp_id); } memory_hoist = memory_hoist.set_tuple_id(isl_dim_out, buffer_footprint_info.cluster_id); - auto schedule = isl::map::from(buffer_footprint_info.outer_schedule.intersect_domain(this->Domain())); + auto schedule = isl::map::from(buffer_footprint_info.outer_schedule.intersect_domain(Domain())); CHECK(schedule.is_single_valued()) << schedule << " is not single-valued schedule"; auto ast_to_schedule = isl::pw_multi_aff(schedule).pullback(iterator_map); - AffSubForAstToSchedule(ast_to_schedule, 
is_transfer_stmt, is_copyin_from_another_band); + ast_to_schedule = AffSubForAstToSchedule(ast_to_schedule, IsTransferStmt(), is_copyin_from_another_band); auto ast_to_original = isl::pw_multi_aff(access).pullback(iterator_map); auto ast_to_scheduled_original = ast_to_schedule.range_product(ast_to_original); auto ast_to_hoisted = isl::pw_multi_aff(memory_hoist).pullback(ast_to_scheduled_original); auto hoist_acs = build.access_from(ast_to_hoisted); if (auto op = hoist_acs.as()) { - if (auto access_ = op.as()) { + if (op.as()) { Array args; for (int i = 1; i < static_cast(op.get_n_arg()); ++i) { args.push_back(Interpret(op.get_arg(i))); } if (node->IsInstance()) - return IslEmitter::EmitAccessNodeFromPromoteAcsProvide(scop_, op.get_arg(0).as().get_id(), - node, args); + return EmitAccessNodeFromPromoteAcsProvide(op.get_arg(0).as().get_id(), node, args); if (node->IsInstance()) - return IslEmitter::EmitAccessNodeFromPromoteAcsCall(scop_, op.get_arg(0).as().get_id(), node, - args); + return EmitAccessNodeFromPromoteAcsCall(op.get_arg(0).as().get_id(), node, args); } } return Evaluate::make(Expr("todo EmitAst")); @@ -569,7 +519,7 @@ Stmt IslEmitter::EmitUserStmtContent(const Evaluate *eva_node) { auto im2col = Call::make(call->type, call->name, args, call->call_type); Stmt res = Evaluate::make(im2col); // add AttrStmt to im2col - for (const auto &item : scop_.data_.vecs) { + for (const auto &item : info_.analysis_result_.GetBufferBindVec()) { Expr replaced = ReplaceLoopVar(var_map_).Mutate(item.second); res = AttrStmt::make(item.first, air::ir::attr::buffer_bind_scope, replaced, res); } @@ -600,8 +550,9 @@ class SubstituteByNameMutator : public IRMutator { * So, we need to sink the copy out statement into the innermost "if", * i.e., copy out immediately after each computation. 
*/ -static Stmt GenerateCopyOut(const Scop &scop, const Provide *original, const Provide *hoisted, const VarMap &var_map) { - auto call_type = scop.GetDtypeOf(hoisted->func->func_name()); +static Stmt GenerateCopyOut(const ScopInfo &info, const Provide *original, const Provide *hoisted, + const VarMap &var_map) { + auto call_type = info.GetDtypeOf(hoisted->func->func_name()); Expr call_expr = Call::make(call_type, hoisted->func->func_name(), hoisted->args, Call::CallType::Halide, hoisted->func, hoisted->value_index); Array new_args; @@ -621,8 +572,8 @@ Stmt IslEmitter::EmitUserStmtContent(const Provide *provide_node) { Expr value = EmitExpr(f, var_map_).Mutate(provide_node->value); Stmt provide_stmt = Provide::make(provide_new->func, provide_new->value_index, value, provide_new->args); - if (scop_.conditional_write_buffer_footprints_.count(write_tensor)) { - return Block::make(provide_stmt, GenerateCopyOut(scop_, provide_node, provide_new, var_map_)); + if (info_.analysis_result_.GetConditionalWriteBufferFootprints().count(write_tensor)) { + return Block::make(provide_stmt, GenerateCopyOut(info_, provide_node, provide_new, var_map_)); } return provide_stmt; } @@ -688,11 +639,11 @@ Stmt IslEmitter::EmitUserStmt(const isl::ast_node_user &node) { isl::ast_expr_op usr_expr = node.get_expr().as(); stmt_id_ = usr_expr.get_arg(0).as().get_id(); node_id_ = node.get_annotation(); - const Node *stmt_node = scop_.data_.statements.at(stmt_id_); + const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(stmt_id_); CHECK(stmt_node); // compute VarMap to replace old iterators auto build = node_info_map_.at(node_id_).build; - auto tuple = scop_.data_.domains.at(stmt_id_).tuple; + auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id_).tuple; auto iterator_map = node_info_map_.at(node_id_).iterator_map; var_map_.clear(); @@ -701,41 +652,51 @@ Stmt IslEmitter::EmitUserStmt(const isl::ast_node_user &node) { auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i)); Expr halide_new_iter = Interpret(isl_expr); var_map_.emplace(isl_old_iter, halide_new_iter); - std::string replace_id = isl_old_iter.get_name() + "_"; - std::vector vec = ExtractIterfromExpr().Run(halide_new_iter); - for (auto item : vec) { - std::string new_name = item->name_hint; - size_t pos = new_name.find(scop_.iter_prefix_); - if (pos != std::string::npos) { - new_name = new_name.replace(pos, scop_.iter_prefix_.size(), replace_id); - iters_old_name_.emplace(item, item->name_hint); - iters_new_name_.emplace(item, new_name); - } - } } - VarMap vmap = var_map_; - stmt_var_map_.emplace(stmt_id_, vmap); return EmitUserStmtContent(stmt_node); } -Stmt IslEmitter::EmitStmt(const isl::ast_node_user &node) { return EmitUserStmt(node); } +Stmt IslEmitter::EmitStmt(const isl::ast_node_user &node) { + CHECK(node.get_expr().isa()); + isl::ast_expr_op usr_expr = node.get_expr().as(); + CHECK(usr_expr); + auto stmt_id = usr_expr.get_arg(0).as().get_id(); + if (info_.IsRead(stmt_id)) { + return Evaluate::make(Expr("todo EmitRead")); + } + if (info_.IsWrite(stmt_id)) { + return Evaluate::make(Expr("todo EmitWrite")); + } + return EmitUserStmt(node); +} Stmt IslEmitter::EmitAst(const isl::ast_node &node) { + Stmt s; + std::string info; if (auto for_node = node.as()) { - return EmitFor(for_node); + info = "[FOR_NODE]"; + s = EmitFor(for_node); } else if (auto if_node = node.as()) { - return EmitIf(if_node); + info = "[IF_NODE]"; + s = EmitIf(if_node); } else if (auto block_node = node.as()) { - return EmitBlock(block_node); + info = 
"[BLOCK_NODE]"; + s = EmitBlock(block_node); } else if (auto mark_node = node.as()) { - return EmitMark(mark_node); + info = "[MARK_NODE]"; + s = EmitMark(mark_node); } else if (auto user_node = node.as()) { - return EmitStmt(user_node); + info = "[USER_NODE]"; + s = EmitStmt(user_node); } else { - LOG(FATAL) << "NYI " << node << "\n"; + s = Evaluate::make(Expr("todo EmitAst")); } - return Evaluate::make(Expr("todo EmitAst")); + if (PRINT_EMMITER) { + LOG(INFO) << ">>>>>>>>>>>>INPUT AST_NODE" << info << "<<<<<<<<<<<<<<\n" << node; + LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << s; + } + return s; } Stmt IslEmitter::Emit(const isl::ast_node &node) { return EmitAst(node); } diff --git a/src/poly/isl_emitter.h b/src/poly/isl_emitter.h index af9b0ed155be637e517e5a8dff9278632c41751d..f9dbe61de076768fbaa7c67bb29add2790a3c110 100644 --- a/src/poly/isl_emitter.h +++ b/src/poly/isl_emitter.h @@ -19,11 +19,9 @@ #include #include -#include #include "ir_pass.h" -#include "poly/isl.h" -#include "poly/scop.h" +#include "poly/scop_info.h" namespace akg { namespace ir { @@ -47,29 +45,31 @@ class IslEmitter { Expr InterpretBinaryOp(const isl::ast_expr_op &e); public: - explicit IslEmitter(Scop &s_, const NodeInfoRepo &n_, const isl::id_list &i_) - : scop_(s_), node_info_map_(n_), iter_names_(i_) {} + explicit IslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) + : info_(info), node_info_map_(n), iter_names_(i) {} virtual ~IslEmitter() = default; - /// Interpret isl::ast_expr to Halide Expr - //@{ + // Interpret isl::ast_expr to Halide Expr Expr Interpret(const isl::ast_expr &e); - //@} // helper functions, which may can be moved into a separated class isl::space GetDomainSpace(const isl::id &stmt_id); isl::space GetSpace(const isl::id &tensor_id, const Array &tensor_index, const isl::id &stmt_id); - isl::multi_aff TensorAccessMultAff(const isl::id &tensor_id, const Array &subscripts, const isl::id &stmt_id); isl::set Domain() const { auto iterator_map = node_info_map_.at(node_id_).iterator_map; return isl::map::from(iterator_map).range(); } Stmt EmitAccessNode(const std::string &name, const Node *node, const Array &tensor_index, const VarMap &var_map_tmp); - Stmt EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, const Node *node, Array &args); - Stmt EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const Node *node, Array &args); - /// Virtual emitters for different type node - //@{ + Stmt EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array &args); + Stmt EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array &args); + Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info); + virtual Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info); + virtual isl::multi_aff TensorAccessMultAff(isl::id &tensor_id, const Array &subscripts, const isl::id &stmt_id); + virtual bool IsTransferStmt(); + virtual bool IsCopyinFromAnotherBand(isl::multi_aff &access); + + // Virtual emitters for different type node virtual Stmt Emit(const isl::ast_node &node); virtual Stmt EmitFor(const isl::ast_node_for &node); virtual Stmt EmitIf(const isl::ast_node_if &node); @@ -84,7 +84,12 @@ class IslEmitter { virtual Stmt EmitUserStmtContent(const IfThenElse *if_node); virtual Stmt EmitUserStmtContent(const For *for_node); virtual Stmt EmitUserStmtContent(const Block *block_node); - //@} + + // Loop isl iters info + virtual void PushIter(const 
Variable *iter); + virtual void PopIter(const Variable *iter); + bool FindIter(const Variable *iter) const; + const Variable *GetIterByName(const std::string &id) const; std::unordered_set realize_use_; std::unordered_set realize_use_with_may_def_; @@ -93,28 +98,16 @@ class IslEmitter { std::unordered_set realize_out_; std::unordered_set global_realize_out_; - /// Scop - Scop &scop_; + ScopInfo &info_; /// Node information map including const NodeInfoRepo &node_info_map_; - /// Loop isl iters info - //@{ /// Loop isl iters list isl::id_list iter_names_; /// Loop declared halide iters std::vector iters_; - virtual void PushIter(const Variable *iter); - virtual void PopIter(const Variable *iter); - bool FindIter(const Variable *iter) const; - const Variable *GetIterByName(const std::string &id) const; - //@} - - std::map iters_old_name_; - std::map iters_new_name_; - // current ast node id isl::id node_id_; // current stmt id @@ -125,7 +118,6 @@ class IslEmitter { // emit in if std::vector cur_if_list_; std::unordered_map, isl::IslIdIslHash> if_map_; - std::unordered_map stmt_var_map_; }; class ExtractIterfromExpr : public air::ir::IRVisitor { @@ -146,16 +138,23 @@ class ExtractIterfromExpr : public air::ir::IRVisitor { std::vector vec_; }; -void FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info, - std::vector active_buffer_fp, isl::id id); -void GetNameWithoutLocal(isl::id &tensor_id, Scop &scop); -bool IsTransferStmt(Scop &scop, isl::id &stmt_id); -bool IsCopyinFromAnotherBand(Scop &scop, isl::multi_aff &access); -void AffSubForAstToSchedule(isl::pw_multi_aff &ast_to_schedule, bool &is_transfer_stmt, - bool &is_copyin_from_another_band); -Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_fp_info); -Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_fp_info, - bool &is_transfer_stmt, Scop &scop); +class ReplaceLoopVar : public air::ir::IRMutator { + public: + explicit ReplaceLoopVar(VarMap v_) : var_map(std::move(v_)) {} + ~ReplaceLoopVar() override = default; + Expr Mutate_(const Variable *op, const Expr &e) final { + for (auto &i : var_map) { + if (op->name_hint == i.first.get_name()) { + return i.second; + } + } + return e; + } + + private: + VarMap var_map; +}; + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/memory_manager.cc b/src/poly/memory_manager.cc deleted file mode 100644 index 148aa8912a415278a7f4dc4b4469f54557555d22..0000000000000000000000000000000000000000 --- a/src/poly/memory_manager.cc +++ /dev/null @@ -1,798 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "scop.h" -#include "poly/dma_inject.h" -#include "scop_builder.h" - -namespace akg { -namespace ir { -namespace poly { -PartitionSingle::PartitionSingle(int times, int tile_start, int cut_m, - const std::map &fractal_int_info) { - m_times_ = times; - m_cut_m_ = cut_m; - m_fractal_int_info_ = fractal_int_info; -} - -PartitionSingle *PartitionSingle::single_ = nullptr; -int PartitionSingle::m_times_ = 0; -int PartitionSingle::m_cut_m_ = 0; -std::map PartitionSingle::m_fractal_int_info_; - -void GetVisitedStmts(const isl::schedule_node &root) { - int n = root.n_children(); - if (n <= 0) return; - - isl::schedule_node node; - if (root.isa()) { - isl::union_set visited_stmts; - for (int i = 0; i < n; ++i) { - node = root.child(i); - auto filter_node = node.as(); - CHECK(filter_node) << "expected children of sequence to be filters"; - auto filter = filter_node.get_filter().universe(); - if (visited_stmts.get()) { - CHECK(visited_stmts.intersect(filter).is_empty()) << "filters are expected to be disjoint as stmt level"; - visited_stmts = visited_stmts.unite(filter); - } else { - visited_stmts = filter; - } - } - } - - for (int i = 0; i < n; ++i) { - node = root.child(i); - GetVisitedStmts(node); - } -} - -std::vector CollectMarkNode(const isl::schedule_node &tree, const std::string &mark_tag) { - std::vector mark_nodes; - tree.foreach_descendant_top_down([&mark_nodes, &mark_tag](const isl::schedule_node &node) -> bool { - if (auto mark_node = node.as()) { - // ignore nested mark nodes - if (mark_node.get_id().get_name() == mark_tag) { - mark_nodes.push_back(node); - return false; - } - } - return true; - }); - return mark_nodes; -} - -const BufferDefInfo &Scop::GetBufferDefInfo(const isl::id &tensor_id) const { - for (const auto &idx : BufferDefInfos()) { - if (idx.dst_tensor_id.get_name() == tensor_id.get_name()) { - return idx; - } - } - LOG(FATAL) << "Hoist footprint of tensor " << tensor_id << " has no buffer definition"; - return place_holder_; -} - -void Scop::RecordAllTensorBufferFootprintToExtension() { - GetVisitedStmts(schedule_.get_root()); - for (size_t index = 0; index < buffer_def_infos_.size(); index++) { - if (buffer_def_infos_[index].find_buffer) continue; - std::string mark_tag = buffer_def_infos_[index].mark_tag; - if (buffer_def_infos_[index].IsIm2col()) { - isl::id nextTensorId = buffer_def_infos_[index].NextTensorDstId(); - mark_tag = GetBufferDefInfo(nextTensorId).mark_tag; - } - this->schedule_ = HoistBufferFootprintAtMarkNode(schedule_.get_root(), mark_tag, index); - } - CHECK_EQ(buffer_footprint_queue_.size(), 0); -} - -isl::schedule_node MapDescendantTopDown(isl::schedule_node node, - const std::function &fn) { - unsigned int depth_ = node.get_tree_depth(); - do { - do { - node = fn(node); - } while (node.has_children() && (node = node.first_child())); - - while (node.get_tree_depth() > depth_ && !node.has_next_sibling()) { - node = node.parent(); - } - - if (node.get_tree_depth() > depth_) { - node = node.next_sibling(); - } - } while (node.get_tree_depth() > depth_); - - return node; -} - -isl::schedule Scop::HoistBufferFootprintAtMarkNode(const isl::schedule_node &root, const std::string &mark_tag, - size_t index) { - auto fn = [mark_tag, index, this](isl::schedule_node node) -> isl::schedule_node { - if (node.isa()) { - std::string mark_id = node.as().get_id().get_name(); - if (mark_id == mark_tag) { - node = HoistBufferFootprintAtMarkNode(node.get_child(0), index); - } - } - return node; - }; - - return MapDescendantTopDown(root, 
fn).get_schedule(); -} - -isl::schedule_node Scop::HoistBufferFootprintAtMarkNode(const isl::schedule_node &tree, size_t index) { - auto schedule = LocalSchedule(tree); - - // hoist cluster and add extension to schedule tree - return HoistTensorClusterFootprint(tree, index, schedule); -} - -void Scop::GatherBufferFootprintDefInfo(const isl::schedule_node &tree, BufferDefInfo &tensor_info) { - auto fp_cluster = tensor_info.GetFootPrintCluster(tree); - std::vector sizes; - if (fp_cluster == nullptr) { - tensor_info.AddSize(tree, sizes); - return; - } - sizes = fp_cluster->GetFixedBoxSizes(); - - isl::id tensor_id = tensor_info.tensor_id; - isl::id cluster_id = tensor_info.dst_tensor_id; - - // build a Halide Node for cluster_id - Array shapes; - for (auto i : sizes) { - shapes.push_back(Expr(static_cast(i))); - } - - Type type = GetDtypeOf(tensor_id); - Tensor tensor = placeholder(shapes, type, cluster_id.get_name()); - const Buffer buffer = decl_buffer(shapes, GetDtypeOf(tensor_id), cluster_id.get_name()); - binds_.Set(tensor, buffer); - - tensor_info.sizes = sizes; - tensor_info.tensor = tensor; - tensor_info.data_type = type; - tensor_info.AddSize(tree, sizes); -} - -void Scop::CollectBufferFootprintDefInfo(BufferDefInfo &tensor_info, const isl::union_map &schedule_prom, - const isl::schedule_node &node) { - tensor_info.footprints_cluster = TensorFootprintCluster::HoistBufferFootprintCluster( - schedule_prom, tensor_info.ancester_tensor_id, data_.reads, data_.copyin, data_.writes, data_.fake_copyin); - if (tensor_info.footprints_cluster != nullptr) { - tensor_info.footprint_cluster_map.emplace_back(std::make_pair(node, tensor_info.footprints_cluster)); - GatherBufferFootprintDefInfo(node, tensor_info); - } -} - -void Scop::HoistIm2colBufferFootprintCluster(const isl::union_map &schedule, const isl::schedule_node &node, - const int index, BufferDefInfo &tensor_info) { - im2col_fp_cluster = ConstructAffineFpCluster(*this, data_.reads, schedule.domain(), schedule, ReferenceType::Read, - AffineType::AFFINE_IM2COL); - tensor_info.footprints_cluster = ConstructAffineFpCluster(*this, data_.reads, schedule.domain(), schedule, - ReferenceType::Read, AffineType::AFFINE_FRACTAL); - CHECK_EQ(index, 0); - CHECK(im2col_fp_cluster != nullptr) << "im2col_fp_cluster must be not null"; - CHECK(tensor_info.footprints_cluster != nullptr) << "footprint cluster in Im2col must be defined"; - tensor_info.footprint_cluster_map.emplace_back(std::make_pair(node, tensor_info.footprints_cluster)); - - if ((tensor_info.footprints_cluster->foot_print_.box.is_valid()) && (im2col_fp_cluster->foot_print_.box.is_valid())) { - GatherBufferFootprintDefInfo(node, tensor_info); - // this update info is used for spec gemm - UpdateFractalIntFirstInfo(IsConvBackpropFilter(), im2col_fp_cluster->GetFixedBoxSizes(), - tensor_info.footprints_cluster->GetFixedBoxSizes()); - } else { - int64_t t_ci = 1; - int64_t k_h = 0; - int64_t k_w = 0; - int64_t t_h = 1; - int64_t t_w = 1; - int64_t s_h = 1; - int64_t s_w = 1; - int64_t t_ho = 1; - int64_t t_wo = 1; - int64_t c_in = 0; - LOG(INFO) << "im2col or fractal foot_print_ box is invalid."; - - auto it = attr_info_.find(ATTR_CONV_KERNEL_H); - if ((it != attr_info_.end()) && (*it).second.as()) k_h = (*it).second.as()->value; - it = attr_info_.find(ATTR_CONV_KERNEL_W); - if ((it != attr_info_.end()) && (*it).second.as()) k_w = (*it).second.as()->value; - it = attr_info_.find(ATTR_CONV_STRIDE_H); - if ((it != attr_info_.end()) && (*it).second.as()) s_h = (*it).second.as()->value; - it = 
attr_info_.find(ATTR_CONV_STRIDE_W); - if ((it != attr_info_.end()) && (*it).second.as()) s_w = (*it).second.as()->value; - it = attr_info_.find(ATTR_CONV_TILE_H); - if ((it != attr_info_.end()) && (*it).second.as()) t_h = (*it).second.as()->value; - it = attr_info_.find(ATTR_CONV_TILE_W); - if ((it != attr_info_.end()) && (*it).second.as()) t_w = (*it).second.as()->value; - it = attr_info_.find(ATTR_CONV_FEATURE_C); - if ((it != attr_info_.end()) && (*it).second.as()) c_in = (*it).second.as()->value; - - t_ho = (t_h - k_h) / s_h + 1; - t_wo = (t_w - k_w) / s_w + 1; - - bool replace_ci = false; - if (!dynamic_shape_.empty()) { - for (const auto &ds : dynamic_shape_) { - if (auto dsn = ds.as()) { - if (dsn->tensor_name == "CI1") { - t_ci = (int64_t)(dsn->poly_upper_bound - 1); - replace_ci = true; - } - } - } - } - if (!replace_ci) { - t_ci = (int64_t)(c_in + 15) / 16; - } - - std::vector sizes; - sizes.push_back(1); // 1 - sizes.push_back((size_t)((t_ho * t_wo + 15) / 16)); // 109 - sizes.push_back((size_t)(t_ci * k_h * k_w)); // 43648 - sizes.push_back(16); // 16 - sizes.push_back(16); // 16 - fractal_int_info_[ATTR_CONV_GMM_M] = t_ho * t_wo; // 1739 - fractal_int_info_[ATTR_CONV_BATCH] = (int64_t)sizes[0]; - fractal_int_info_[ATTR_CONV_TILE_M] = (int64_t)sizes[1]; - fractal_int_info_[ATTR_CONV_TILE_K] = (int64_t)sizes[2]; - fractal_int_info_[ATTR_CONV_M_INNER] = (int64_t)sizes[3]; - fractal_int_info_[ATTR_CONV_K_INNER] = (int64_t)sizes[4]; - GatherFractalDefInfo(node, tensor_info, sizes); - } - fractal_int_info_[ATTR_CONV_FEATURE_W] = ExtractExprFromAttrs(ATTR_CONV_FEATURE_W); - fractal_int_info_[ATTR_CONV_PAD_LEFT] = ExtractExprFromAttrs(ATTR_CONV_PAD_LEFT); - fractal_int_info_[ATTR_CONV_PAD_RIGHT] = ExtractExprFromAttrs(ATTR_CONV_PAD_RIGHT); -} - -void Scop::MakeMultiBufferFootprint(const isl::union_map &schedule, const isl::schedule_node &node, int &index, - BufferDefInfo &tensor_info) { - if (!IsCopyinTensor(tensor_info.ancester_tensor_id.get_name())) { - CollectBufferFootprintDefInfo(tensor_info, schedule, node); - } else { - if (index == 0) { - CollectBufferFootprintDefInfo(tensor_info, schedule, node); - } else { - isl::id new_dst_id = tensor_info.GetIndexDstId(ctx_, tensor_info.dst_tensor_id, index); - BufferDefInfo new_footprint_info = BufferDefInfo{tensor_info.tensor_id, - new_dst_id, - tensor_info.ancester_tensor_id, - tensor_info.mem_type, - tensor_info.mark_tag, - false, - tensor_info.is_bind_tensor, - tensor_info.MakeDataStream(new_dst_id), - Tensor(), - Handle(), - tensor_info.sizes, - nullptr, - isl::union_map::empty(CreateParamsSpace(ctx_))}; - CollectBufferFootprintDefInfo(new_footprint_info, schedule, node); - buffer_def_infos_.push_back(new_footprint_info); - } - } -} - -void Scop::UpdateSpecGemmFractalInfo(const BufferDefInfo &tensor_info) { - if (IsConv() && IsB(tensor_info.tensor_id.get_name())) { - CHECK(tensor_info.footprints_cluster != nullptr); - UpdateFractalIntLastInfo(tensor_info.footprints_cluster->GetFixedBoxSizes()); - fractal_str_info_[ATTR_CONV_GMM_WEIGHT] = tensor_info.dst_tensor_id.get_name(); - CHECK_NE(tensor_info.dst_tensor_id.get_name(), ""); - } else if (IsConv() && IsA(tensor_info.tensor_id.get_name())) { - fractal_str_info_[ATTR_CONV_GMM_FEATURE] = tensor_info.data_stream[2].first.get_name(); - CHECK_NE(tensor_info.dst_tensor_id.get_name(), ""); - } else if (IsConv() && IsC(tensor_info.tensor_id.get_name())) { - fractal_str_info_[ATTR_CONV_GMM_RES] = tensor_info.dst_tensor_id.get_name(); - CHECK_NE(tensor_info.dst_tensor_id.get_name(), ""); - } 
-} - -void Scop::MakeBufferFootprintCluster(BufferDefInfo &tensor_info) { - std::vector nodes = CollectMarkNode(schedule_.get_root(), tensor_info.mark_tag); - int index = 0; - for (const auto &node : nodes) { - isl::schedule_node tree = node.get_child(0); - auto schedule = LocalSchedule(tree); - - // get TensorFootPrintsCluster for each tensor - if (tensor_info.IsIm2col()) { - HoistIm2colBufferFootprintCluster(schedule, node, index, tensor_info); - } else { - if (tensor_info.IsGemmDataL12L0() || tensor_info.IsGemmWeightL12L0()) { - AddGemmTransposeFpCluster(schedule); - } - MakeMultiBufferFootprint(schedule, node, index, tensor_info); - UpdateSpecGemmFractalInfo(tensor_info); - } - index++; - } -} - -isl::union_set CollectDomain(const isl::schedule_node &node) { - int depth = node.get_tree_depth(); - isl::schedule_node tmp_node; - isl::union_set domain = node.get_domain(); - for (int i = 0; i < depth; ++i) { - tmp_node = node.ancestor(depth - i); - if (auto filter_node = tmp_node.as()) { - domain = domain.intersect(filter_node.get_filter()); - } - if (auto extension_node = tmp_node.as()) { - auto parent_schedule = ShortSchedule(tmp_node); - auto extension = extension_node.get_extension(); - parent_schedule = parent_schedule.intersect_domain(domain); - domain = domain.unite(parent_schedule.range().apply(extension)); - } - } - return domain; -} - -std::vector Scop::CollectBufferedFootprintsIndexes(const isl::union_set &active_domains, - const isl::id &tensor_id) const { - std::vector result; - - for (size_t i = 0, e = active_buffer_footprints_.size(); i < e; ++i) { - const auto &act_fp = active_buffer_footprints_[i]; - if (act_fp.first.intersect(active_domains).is_empty()) { - continue; - } - - auto cluster_id = act_fp.second.cluster_id; - for (const auto &def_iter : BufferDefInfos()) { - if (def_iter.dst_tensor_id.get_name() == cluster_id.get_name() && - def_iter.tensor_id.get_name() == tensor_id.get_name()) { - result.push_back(i); - break; - } - } - } - return result; -} - -std::vector> Scop::CollectBufferedFootprints( - const isl::union_set &active_domains, const isl::id &tensor_id) const { - std::vector> result; - - for (auto idx : CollectBufferedFootprintsIndexes(active_domains, tensor_id)) { - result.emplace_back(active_buffer_footprints_[idx]); - } - return result; -} - -std::shared_ptr Scop::GetFootPrintsCluster(const isl::id &tensor_id) { - for (const auto &info : buffer_def_infos_) { - if (info.tensor_id.get_name() == tensor_id.get_name()) { - return info.footprints_cluster; - } - } - return nullptr; -} - -isl::schedule_node Scop::HoistTensorClusterFootprint(isl::schedule_node tree, size_t buffered_fp_idx, - const isl::union_map &schedule) { - BufferDefInfo &tensor_info = buffer_def_infos_[buffered_fp_idx]; - - isl::schedule_node mark_node = tree; - if (tree.has_parent()) { - mark_node = tree.parent(); - } - - isl::id src_tensor_id = tensor_info.tensor_id; - isl::id dst_tensor_id = tensor_info.dst_tensor_id; - bool is_bind_tensor = tensor_info.is_bind_tensor; - - auto fp_cluster = tensor_info.GetFootPrintCluster(mark_node); - if ((fp_cluster == nullptr) || (!fp_cluster->foot_print_.box.is_valid())) { - LOG(INFO) << "FootprintsClusters: fp_cluster is null or box is invalid! 
src: " << src_tensor_id - << ", dst: " << dst_tensor_id; - return tree; - } - - auto active_domains = CollectDomain(tree); - auto active_buf_fp = CollectBufferedFootprints(active_domains, src_tensor_id); - auto foot_prints = isl::set::empty(fp_cluster->GetSingleAccessRange().get_space()); - auto all_read_only = fp_cluster->UnWriteable(); - for (const auto &buf_fp : active_buf_fp) { - foot_prints = foot_prints.unite(buf_fp.second.cluster->GetSingleAccessRange()); - all_read_only = all_read_only && buf_fp.second.cluster->UnWriteable(); - } - - if (is_bind_tensor && tensor_info.mem_type != MemType::UBL0_) { - if (!(IsGemm() && tensor_info.IsCubeCL1Write())) { - bool insert_ub_to_l1 = false; - if (!data_.fake_copyin.is_empty()) { - data_.fake_copyin.foreach_map([&insert_ub_to_l1, &src_tensor_id, &dst_tensor_id](const isl::map &m) -> void { - if ((m.get_tuple_id(isl_dim_out).get_name() == src_tensor_id.get_name()) && - (src_tensor_id.get_name() + "_local_L1" == dst_tensor_id.get_name())) { - insert_ub_to_l1 = true; - } - }); - } - if (insert_ub_to_l1) { - isl::id outer_tensorId = isl::id(src_tensor_id.ctx(), src_tensor_id.get_name() + "_local_UB"); - tree = - PlaceInnerDataCopyBelow(*this, tree, *fp_cluster, *fp_cluster, src_tensor_id, dst_tensor_id, outer_tensorId); - } else { - tree = PlaceOuterDataCopyBelow(*this, tree, *fp_cluster, src_tensor_id, dst_tensor_id); - } - } else { - buffer_footprint_queue_.push(src_tensor_id); - } - // If the new buffer_footprint is not a strict subset of any other parent - auto cluster = std::shared_ptr(std::move(fp_cluster)); - active_buffer_footprints_.emplace_back( - std::make_pair(active_domains, BufferedFootPrintInfo{cluster, schedule, dst_tensor_id})); - tensor_info.find_buffer = true; - return tree; - } - - if (tensor_info.IsIm2col()) { - isl::id cluster_id = tensor_info.NextTensorDstId(); - auto l0_fp_cluster = GetFootPrintsCluster(dst_tensor_id); - CHECK(l0_fp_cluster != nullptr); - tree = PlaceIm2colBelow(*this, tree, *l0_fp_cluster, *fp_cluster, cluster_id, dst_tensor_id); - // If the new buffer_footprint is not a strict subset of any other parent - auto cluster = std::shared_ptr(std::move(l0_fp_cluster)); - active_buffer_footprints_.emplace_back( - std::make_pair(active_domains, BufferedFootPrintInfo{cluster, schedule, dst_tensor_id})); - tensor_info.find_buffer = true; - SetFindBuffer(dst_tensor_id, true); - return tree; - } - - if (tensor_info.IsGemmDataL12L0()) { - if (IsGemmDataTranspose()) { - const isl::id &trans_id = dst_tensor_id; - const isl::id &cluster_id = dst_tensor_id; - tree = PlaceIm2colBelow(*this, tree, *gemm_a_transpose_fp_cluster_, *fp_cluster, trans_id, cluster_id); - active_buffer_footprints_.emplace_back( - std::make_pair(active_domains, BufferedFootPrintInfo{gemm_a_transpose_fp_cluster_, schedule, cluster_id})); - } - } - - if (tensor_info.IsGemmWeightL12L0()) { - if (IsGemmWeightTranspose()) { - const isl::id &trans_id = dst_tensor_id; - const isl::id &cluster_id = dst_tensor_id; - tree = PlaceIm2colBelow(*this, tree, *gemm_b_transpose_fp_cluster_, *fp_cluster, trans_id, cluster_id); - active_buffer_footprints_.emplace_back( - std::make_pair(active_domains, BufferedFootPrintInfo{gemm_b_transpose_fp_cluster_, schedule, cluster_id})); - } - } - auto scop_cluster = fp_cluster; - if (IsGemm() && (tensor_info.IsGemmDataL12L0() || tensor_info.IsGemmWeightL12L0())) { - scop_cluster = GetBufferDefInfo(tensor_info.tensor_id).footprints_cluster; - } - if (tensor_info.IsPreCubeTile2Write()) { - auto info = 
GetBufferDefInfo(tensor_info.tensor_id); - auto new_scop_group = info.GetFootPrintCluster(mark_node); - if (new_scop_group != nullptr) { - scop_cluster = new_scop_group; - } - } - tree = PlaceInnerDataCopyBelow(*this, tree, *fp_cluster, *scop_cluster, src_tensor_id, dst_tensor_id, src_tensor_id); - if (IsGemm() && !buffer_footprint_queue_.empty() && - buffer_footprint_queue_.front().get_name() == tensor_info.ancester_tensor_id.get_name()) { - tree = PlaceOuterDataCopyBelow(*this, tree, *fp_cluster, tensor_info.ancester_tensor_id, src_tensor_id); - buffer_footprint_queue_.pop(); - } - - // If the new buffer_footprint is not a strict subset of any other parent - auto group = std::shared_ptr(std::move(fp_cluster)); - - active_buffer_footprints_.emplace_back( - std::make_pair(active_domains, BufferedFootPrintInfo{group, schedule, dst_tensor_id})); - tensor_info.find_buffer = true; - return tree; -} - -void Scop::ReorderBufferedDefInfos() { - if (data_.fake_copyin.is_empty()) { - return; - } - - std::unordered_set tensors; - data_.fake_copyin.foreach_map( - [&tensors](const isl::map &m) -> void { tensors.insert(m.get_tuple_id(isl_dim_out).get_name()); }); - - for (size_t index = 1; index < buffer_def_infos_.size(); index++) { - if ((buffer_def_infos_[index].mark_tag == REALIZE_L1) && - (tensors.find(buffer_def_infos_[index].tensor_id.get_name()) != tensors.end())) { - BufferDefInfo promoted_info = buffer_def_infos_[index]; - buffer_def_infos_.erase(buffer_def_infos_.begin() + static_cast(index)); - buffer_def_infos_.insert(buffer_def_infos_.begin(), promoted_info); - } - } -} - -int Scop::CountBufferDefInfo(const isl::id &tensor_id) const { - int num = 0; - for (const auto &tensorIter : BufferDefInfos()) { - if (tensorIter.dst_tensor_id.get_name() == tensor_id.get_name()) { - num++; - } - } - return num; -} - -void Scop::AddGemmTransposeFpCluster(const isl::union_map &schedule) { - auto domain = schedule.domain(); - if (IsGemmDataTranspose()) { - if (IsGemmDataTransposeBlock()) { - gemm_a_transpose_fp_cluster_ = ConstructAffineFpCluster(*this, data_.reads, domain, schedule, ReferenceType::Read, - AffineType::AFFINE_GEMMBLOCK, AffineTensor::LEFT_TENSOR); - } else if (IsGemmDataTransposeInnerBlock()) { - gemm_a_transpose_fp_cluster_ = - ConstructAffineFpCluster(*this, data_.reads, domain, schedule, ReferenceType::Read, - AffineType::AFFINE_GEMMBLOCKIN, AffineTensor::LEFT_TENSOR); - } else { - gemm_a_transpose_fp_cluster_ = ConstructAffineFpCluster(*this, data_.reads, domain, schedule, ReferenceType::Read, - AffineType::AFFINE_GEMM, AffineTensor::LEFT_TENSOR); - } - } - if (IsGemmWeightTranspose()) { - if (IsGemmWeightTransposeBlock()) { - gemm_b_transpose_fp_cluster_ = ConstructAffineFpCluster(*this, data_.reads, domain, schedule, ReferenceType::Read, - AffineType::AFFINE_GEMMBLOCK, AffineTensor::RIGHT_TENSOR); - } else if (IsGemmWeightTransposeInnerBlock()) { - gemm_b_transpose_fp_cluster_ = - ConstructAffineFpCluster(*this, data_.reads, domain, schedule, ReferenceType::Read, - AffineType::AFFINE_GEMMBLOCKIN, AffineTensor::RIGHT_TENSOR); - } else { - gemm_b_transpose_fp_cluster_ = ConstructAffineFpCluster(*this, data_.reads, domain, schedule, ReferenceType::Read, - AffineType::AFFINE_GEMM, AffineTensor::RIGHT_TENSOR); - } - } -} - -void GetAffOffsetAndNumVars(const isl::aff &aff, int &offset, int &num_vars) { - offset = aff.get_constant_val().get_num_si(); - - num_vars = 0; - int dim = isl_aff_dim(aff.get(), isl_dim_in); - CHECK_GE(dim, 0); - for (int j = 0; j < dim; ++j) { - isl_val *coef = 
isl_aff_get_coefficient_val(aff.get(), isl_dim_in, j); - int coef_val = isl_val_get_num_si(coef); - static_cast(isl_val_free(coef)); - if (coef_val != 0) ++num_vars; - } -} - -/* - * Check the isl::aff is in the form of { [i0, i1, i2, i3, i4] -> [(-64 + i2)] } - * i.e. the mapping is one variable plus a non-zero constant offset. - */ -bool IsAffVarPlusOffset(const isl::aff &aff) { - int offset = 0, num_vars = 0; - GetAffOffsetAndNumVars(aff, offset, num_vars); - return offset != 0 && num_vars == 1; -} - -/* - * Check the isl::aff is in the form of { [i0, i1, i2, i3, i4] -> [(64)] } - * i.e. the mapping is a non-zero constant. - */ -bool IsAffNonZeroConst(const isl::aff &aff) { - int offset = 0, num_vars = 0; - GetAffOffsetAndNumVars(aff, offset, num_vars); - return offset != 0 && num_vars == 0; -} - -static isl::pw_multi_aff ComputeNewBufferFootprint(const std::shared_ptr &fp_cluster, - const isl::pw_multi_aff &buffer_footprint) { - if (!fp_cluster->UnWriteable()) return buffer_footprint; - if (!fp_cluster->foot_print_.is_valid) return buffer_footprint; - unsigned num_dims = fp_cluster->foot_print_.GetBoxDim(); - - isl::pw_multi_aff new_buffer_footprint = buffer_footprint; - for (unsigned dim = 0; dim < num_dims; ++dim) { - isl::aff lower_bound = fp_cluster->foot_print_.GetBoxLowerBound(dim); - isl::pw_aff dim_buf_fp = buffer_footprint.get_pw_aff(dim); - if (dim_buf_fp.n_piece() != 1) return buffer_footprint; - // there is only one piece, but we have to use the foreach API - dim_buf_fp.foreach_piece([&lower_bound, &new_buffer_footprint, &dim](const isl::set &set, - const isl::aff &aff) -> void { - if (IsAffVarPlusOffset(lower_bound) && IsAffNonZeroConst(aff)) { - isl::pw_aff zero = isl::pw_aff(isl::manage(isl_aff_set_constant_si(aff.copy(), 0))); - new_buffer_footprint = isl::manage(isl_pw_multi_aff_set_pw_aff(new_buffer_footprint.copy(), dim, zero.copy())); - } - }); - } - return new_buffer_footprint; -} - -/* - * Remove the constant offset from provide args, e.g. input_1_local_UB(32, 7, cc2, cc3) = input_1(...) - * Check the footprint cluster of the hoisted var to confirm this input tensor has multiple accesses - * from shifted tiles. This should be improved by computing the new footprint with footprint_per_access(), - * but from isl AST we do not know the footprint ID that corresponds to the GM -> UB copy. 
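 * A small worked illustration of the intent (the concrete values are assumptions, not taken from a real kernel):
 * when the footprint box lower bound of a dimension has the form { [i0, i1, i2, i3, i4] -> [(-64 + i2)] }
 * and the corresponding buffer footprint argument is a non-zero constant, ComputeNewBufferFootprint resets
 * that constant to 0, e.g. a footprint [(32), (7), cc2, cc3] becomes [(0), (0), cc2, cc3];
 * variable arguments such as cc2 and cc3 are left untouched.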
- */ -isl::pw_multi_aff Scop::RemoveConstOffsetFromBufferFootprint(const isl::pw_multi_aff &buffer_footprint) { - const isl::id buffer_id = buffer_footprint.get_tuple_id(isl_dim_out); - for (const auto &act_buf : ActiveBufferFootprints()) { - if (act_buf.second.cluster_id == buffer_id) { - const auto &footprint_cluster = act_buf.second.cluster; - return ComputeNewBufferFootprint(footprint_cluster, buffer_footprint); - } - } - return buffer_footprint; -} - -bool Scop::HasBufferDefInfo(const isl::id &tensor_id) const { - for (const auto &idx : BufferDefInfos()) { - if (idx.dst_tensor_id.get_name() == tensor_id.get_name()) { - return true; - } - } - return false; -} - -void Scop::UpdateFractalIntFirstInfo(bool is_conv_backprop_filter, const std::vector &im2col_fp_cluster_size, - const std::vector &fractal_fp_cluster_size) { - if (is_conv_backprop_filter) { - UpdateFractalIntFirstInfoConvBackpropFilter(im2col_fp_cluster_size, fractal_fp_cluster_size); - } else { - UpdateFractalIntFirstInfoConvForward(im2col_fp_cluster_size, fractal_fp_cluster_size); - } -} - -void Scop::UpdateFractalIntLastInfo(std::vector filter_fp_cluster_size) { - if (IsConvBackpropInput()) { - CHECK_EQ(filter_fp_cluster_size.size(), 4); - // conv_backprop_input filter: [ko, no, ni, ki] - int64_t kh = ExtractIntFromAttrs(ATTR_CONV_KERNEL_H); - int64_t kw = ExtractIntFromAttrs(ATTR_CONV_KERNEL_W); - fractal_int_info_[ATTR_CONV_TILE_CO] = (int64_t)filter_fp_cluster_size[0] / (kh * kw); - fractal_int_info_[ATTR_CONV_TILE_N] = (int64_t)filter_fp_cluster_size[0] / (kh * kw); - - fractal_int_info_[ATTR_CONV_N_INNER] = (int64_t)filter_fp_cluster_size[2]; - } else if (IsConvBackpropFilter()) { - CHECK_EQ(filter_fp_cluster_size.size(), 5); - // conv_backprop_filter filter: [batch, no, mo, ni, mi] - fractal_int_info_[ATTR_CONV_TILE_M] = (int64_t)filter_fp_cluster_size[1]; - fractal_int_info_[ATTR_CONV_M_INNER] = (int64_t)filter_fp_cluster_size[3]; - fractal_int_info_[ATTR_CONV_GMM_M] = (int64_t)filter_fp_cluster_size[1] * filter_fp_cluster_size[3]; - } else { - CHECK_EQ(filter_fp_cluster_size.size(), 4); - // conv_forward filter: [ko, no, ni, ki] - fractal_int_info_[ATTR_CONV_TILE_CO] = (int64_t)filter_fp_cluster_size[1]; - fractal_int_info_[ATTR_CONV_TILE_N] = (int64_t)filter_fp_cluster_size[1]; - fractal_int_info_[ATTR_CONV_N_INNER] = (int64_t)filter_fp_cluster_size[2]; - } -} - -// set the findPromote to the given tensor_id in buffered_decl_infos_ -// based on tensor_id_ -void Scop::SetFindBuffer(const isl::id &tensor_id, bool find_buffer) { - for (auto &info : buffer_def_infos_) { - if (info.tensor_id.get_name() == tensor_id.get_name()) { - info.find_buffer = find_buffer; - return; - } - } - LOG(FATAL) << "hosited tensor" << tensor_id << "has no declaration"; -} - -void Scop::UpdateFractalIntFirstInfoConvBackpropFilter(std::vector im2col_fp_cluster_size, - std::vector fractal_fp_cluster_size) { - CHECK_EQ(fractal_fp_cluster_size.size(), 5); - fractal_int_info_[ATTR_CONV_BATCH] = (int64_t)fractal_fp_cluster_size[0]; - fractal_int_info_[ATTR_CONV_TILE_K] = (int64_t)fractal_fp_cluster_size[1]; - fractal_int_info_[ATTR_CONV_TILE_N] = (int64_t)fractal_fp_cluster_size[2]; - fractal_int_info_[ATTR_CONV_N_INNER] = (int64_t)fractal_fp_cluster_size[3]; - fractal_int_info_[ATTR_CONV_K_INNER] = (int64_t)fractal_fp_cluster_size[4]; - - fractal_int_info_[ATTR_CONV_TILE_CO] = (int64_t)fractal_fp_cluster_size[2]; - - CHECK_EQ(im2col_fp_cluster_size.size(), 6); - fractal_int_info_[ATTR_CONV_GMM_K] = (int64_t)im2col_fp_cluster_size[1]; -} - -void 
Scop::UpdateFractalIntFirstInfoConvForward(std::vector im2col_fp_cluster_size, - std::vector fractal_fp_cluster_size) { - CHECK_EQ(fractal_fp_cluster_size.size(), 5); - fractal_int_info_[ATTR_CONV_BATCH] = (int64_t)fractal_fp_cluster_size[0]; - fractal_int_info_[ATTR_CONV_TILE_M] = (int64_t)fractal_fp_cluster_size[1]; - fractal_int_info_[ATTR_CONV_TILE_K] = (int64_t)fractal_fp_cluster_size[2]; - fractal_int_info_[ATTR_CONV_M_INNER] = (int64_t)fractal_fp_cluster_size[3]; - fractal_int_info_[ATTR_CONV_K_INNER] = (int64_t)fractal_fp_cluster_size[4]; - - CHECK_EQ(im2col_fp_cluster_size.size(), 6); - fractal_int_info_[ATTR_CONV_GMM_M] = (int64_t)im2col_fp_cluster_size[1]; -} - -isl::union_map LocalScheduleImpl(const isl::schedule_node &node, bool use_node) { - int tree_depth = node.get_tree_depth(); - int new_tree_depth = tree_depth; - if (use_node) ++new_tree_depth; - isl::schedule_node tmp_node; - isl::union_map schedule = isl::union_map::from_domain(node.get_domain()); - for (int i = 0; i < new_tree_depth; ++i) { - tmp_node = node.ancestor(tree_depth - i); - if (auto band_node = tmp_node.as()) { - if (band_node.n_member() > 0) { - schedule = schedule.flat_range_product(band_node.get_partial_schedule_union_map()); - } - } else if (auto filter_node = tmp_node.as()) { - schedule = schedule.intersect_domain(filter_node.get_filter()); - } else if (auto extension_node = tmp_node.as()) { - schedule = schedule.unite(extension_node.get_extension().reverse().intersect_range(schedule.range())); - } - } - return schedule; -} - -isl::union_map ShortSchedule(const isl::schedule_node &node) { return LocalScheduleImpl(node, false); } - -isl::union_map LocalSchedule(const isl::schedule_node &node) { return LocalScheduleImpl(node, true); } - -void Scop::GatherFractalDefInfo(const isl::schedule_node &tree, BufferDefInfo &tensor_info, - std::vector &sizes) { - isl::id tensor_id = tensor_info.tensor_id; - isl::id cluster_id = tensor_info.dst_tensor_id; - - Array shapes; - for (auto i : sizes) { - shapes.push_back(Expr(static_cast(i))); - } - - Type type = GetDtypeOf(tensor_id); - Tensor tensor = placeholder(shapes, type, cluster_id.get_name()); - const Buffer buffer = decl_buffer(shapes, GetDtypeOf(tensor_id), cluster_id.get_name()); - binds_.Set(tensor, buffer); - - tensor_info.sizes = sizes; - tensor_info.tensor = tensor; - tensor_info.data_type = type; - tensor_info.AddSize(tree, sizes); -} - -/* - * Update sizes of a specific tensor in order to support realize shape expansion in UB -> L1 strided copy - * param new_sizes: new shape of the tensor - * return: found or not found - */ -bool Scop::UpdateBufferDefInfoSizes(const isl::id &tensor_id, const std::vector &new_sizes) { - for (auto &info : buffer_def_infos_) { - // update the first occurrence - if (info.dst_tensor_id == tensor_id) { - auto old_sizes = info.sizes; - CHECK(old_sizes.size() == new_sizes.size()); - Array shapes; - for (size_t dim = 0; dim < new_sizes.size(); ++dim) { - size_t new_size = std::max(new_sizes[dim], old_sizes[dim]); - shapes.push_back(Expr(static_cast(new_size))); - } - Tensor tensor = placeholder(shapes, info.data_type, tensor_id.get_name()); - const Buffer buffer = decl_buffer(shapes, info.data_type, tensor_id.get_name()); - binds_.Set(tensor, buffer); - - info.sizes = new_sizes; - info.tensor = tensor; - return true; - } - } - return false; -} - -} // namespace poly -} // namespace ir -} // namespace akg diff --git a/src/poly/reschedule.h b/src/poly/pass_info.cc similarity index 67% rename from src/poly/reschedule.h rename to 
src/poly/pass_info.cc index c64205285f602fe66e075ceaf606d4f58db670ca..cb417a0ffaa8cccace701f0f0ca11a28e19cee3b 100644 --- a/src/poly/reschedule.h +++ b/src/poly/pass_info.cc @@ -13,21 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef POLY_RESCHEDULE_H_ -#define POLY_RESCHEDULE_H_ +#include "poly/pass_info.h" -#pragma once -#include "poly/transform.h" +#include +#include +#include + +#include +#include +#include +#include namespace akg { namespace ir { -namespace poly { - -isl::schedule_node ReorderFilters(const isl::schedule_node &node, - const std::unordered_map &old_to_new_map); - -} // namespace poly +namespace poly {} // namespace poly } // namespace ir } // namespace akg - -#endif // POLY_RESCHEDULE_H_ diff --git a/src/poly/pass_info.h b/src/poly/pass_info.h new file mode 100644 index 0000000000000000000000000000000000000000..42ce3f7358ef55e58427fd416079a8ecbd877511 --- /dev/null +++ b/src/poly/pass_info.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_PASS_INFO_H_ +#define POLY_PASS_INFO_H_ + +#include +#include +#include +#include "isl.h" + +namespace akg { +namespace ir { +namespace poly { +using ReduceStmtMap = std::unordered_map, isl::IslIdIslHash>; +class Dependency { + private: + isl::id start_node_id_; + isl::id end_node_id_; + int64_t edge_weight_; + + public: + Dependency(const isl::id start_node_id, const isl::id end_node_id, const int64_t edge_weight) + : start_node_id_(start_node_id), end_node_id_(end_node_id), edge_weight_(edge_weight) {} + ~Dependency() {} + + isl::id GetStartNode() { return start_node_id_; } + isl::id GetEndNode() { return end_node_id_; } + int64_t GetEdgeWeight() const { return edge_weight_; } +}; + +// pass info on schedule transform +class PassInfo { + public: + PassInfo() {} + ~PassInfo() {} + + bool has_grouped_{false}; + bool tile_check_coincident_{false}; + + std::unordered_map group_filter_map_; + + std::vector dependency_list_; + + isl::union_pw_multi_aff group_upma_; + + isl::schedule_constraints constraints_; + + bool coincident_{true}; + + isl::union_map dependences_; + + isl::union_map orig_dependences_; + isl::union_set transfer_stmt_; + ReduceStmtMap reduce_stmts_; + std::map invariant_state_; + bool has_invariant_dependence_{false}; + + bool restart_{false}; + +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_PASS_INFO_H_ diff --git a/src/poly/pass_mgr_strategy.h b/src/poly/pass_mgr_strategy.h new file mode 100644 index 0000000000000000000000000000000000000000..7758b2b0f59d9b89d1bf0b675bcd0c9012b5b054 --- /dev/null +++ b/src/poly/pass_mgr_strategy.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef POLY_PASS_MGR_STRATEGY_H_ +#define POLY_PASS_MGR_STRATEGY_H_ + +#include "poly/schedule_pass.h" +#include "poly/dump_log.h" + +#include "poly/schedule_pass/init_schedule.h" +#include "poly/schedule_pass/compute_schedule.h" + +namespace akg { +namespace ir { +namespace poly { + +class PassMgrStrategy { + public: + explicit PassMgrStrategy(ScopInfo &scop_info) : scop_info_(scop_info) {} + + void RegisterPass(std::shared_ptr pass) { + CHECK(pass); + passes_.emplace_back(std::move(pass)); + } + void RegisterNormalizationPasses() { RegisterPass(std::make_shared(pass_info_, scop_info_)); } + void RegisterSchedulingPasses() { RegisterPass(std::make_shared(pass_info_, scop_info_)); } + virtual void RegisterTilingPasses() = 0; // each backend has different achievement + virtual void RegisterMemPromPasses() = 0; // each backend has different achievement + virtual void RegisterPasses() = 0; + const std::vector> &GetPasses() const { return passes_; }; + + virtual ~PassMgrStrategy() = default; + + ScopInfo &scop_info_; + PassInfo pass_info_; + + protected: + std::vector> passes_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg +#endif // POLY_PASS_MGR_STRATEGY_H_ \ No newline at end of file diff --git a/src/poly/poly.cc b/src/poly/poly.cc index 613773fa84e77acc2248aa9d9a64c5489a14699b..449b05b8a74413ec02f998f89cd713aeba34e785 100644 --- a/src/poly/poly.cc +++ b/src/poly/poly.cc @@ -13,15 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include -#include - -#include "ir_pass.h" #include "poly/scop.h" -#include "pass/utils.h" - namespace akg { namespace ir { /*! @@ -31,63 +24,72 @@ class Poly { public: Poly() : isl_ctx_(isl::ctx(isl_ctx_alloc())) {} + ~Poly() noexcept { + scop_.reset(); + // scop must be deconstructed before isl_ctx is deconstructed + isl_ctx_free(isl_ctx_.get()); + } + void Run(const Stmt &stmt, const Map &extern_buffer, const Map &attrs, const bool is_spec_gemm, bool is_tuning, bool is_dynamic) { stmt_ = stmt; - scop_.reset(new poly::Scop(Simplify_cce(stmt_), extern_buffer, isl_ctx_, is_spec_gemm)); + scop_.reset(new poly::Scop(Simplify_cce(stmt_), isl_ctx_)); CHECK(scop_ != nullptr); + scop_->ParseUserConfig(attrs, extern_buffer, is_spec_gemm, is_tuning, is_dynamic); - scop_->SetAttrs(attrs); - scop_->is_dynamic_ = is_dynamic; - - // generate isl schedule from Halide std::chrono::high_resolution_clock::time_point timer_start; + // generate isl schedule from Halide TIMER_START; isl::schedule sch = scop_->GenIsl(); TIMER_SHOW("GenIsl", std::string(is_spec_gemm ? "_specgemm" : "")); - // transform isl schedule with coincidence constraints - isl::schedule scht = scop_->Transform(sch, true, is_tuning); - if (is_tuning) return; - - if (scht.get() == sch.get()) { - // transform failed, redo transform without coincidence constraints - scht = scop_->Transform(sch, false); - } + // isl schedule transform + TIMER_START; + isl::schedule sched = scop_->Transform(sch); + TIMER_SHOW("Transform", std::string(is_spec_gemm ? 
"_specgemm" : "")); // generate Halide from isl schedule - stmt_ = scop_->GenHalide(scht); + TIMER_START; + stmt_ = scop_->GenHalide(sched); + TIMER_SHOW("GenHalide", std::string(is_spec_gemm ? "_specgemm" : "")); + + if (is_dynamic) stmt_ = RestoreCombinedParams(stmt_, scop_->info_); + + if (is_tuning) { + spaces_ = GenerateTilingSpace(sched, scop_->info_, stmt_, scop_->info_.user_config_.GetDumpTuningLevel()); + return; + } // optimize post poly Halide IR for Davinci - if (scop_->enable_feature_library_ || scop_->optimize_for_davinci_) { - stmt_ = poly::OptimizeHalide(stmt_, !scop_->params_.empty()); + if (scop_->info_.user_config_.GetEnableFeatureLib() || scop_->info_.user_config_.GetOptimizeForDavinci()) { + stmt_ = poly::DavinciHalideOptimizer(stmt_, !scop_->info_.user_config_.GetParams().empty()); } - gen_empty_tiling = scop_->is_tiled_; + gen_empty_tiling = scop_->info_.analysis_result_.GetIsTiled(); } - ~Poly() noexcept { - scop_.reset(); - // scop must be deconstructed before isl_ctx is deconstructed - isl_ctx_free(isl_ctx_.get()); - } + Stmt GetStmt() { return stmt_; } - Stmt getstmt() { return stmt_; } - bool gen_empty_tiling{false}; - Array getTilingParams() { + NodeRef GetSpaces() { return spaces_; } + + Array GetTilingParams() { CHECK(scop_ != nullptr); Array tiling_params_array; if (gen_empty_tiling) return tiling_params_array; std::unordered_set tiling_params; - for (const auto &kv : scop_->param_tiling_map_) { + auto param_tiling_map = scop_->info_.user_config_.GetParamTilingMap(); + for (const auto &kv : param_tiling_map) { GatherVars(kv.second, &tiling_params); } for (const auto ¶m : tiling_params) tiling_params_array.push_back(param); return tiling_params_array; } - NodeRef getspaces() { - CHECK(scop_ != nullptr); - return scop_->spaces_; + void GatherVars(const Expr expr, std::unordered_set *vset) { + PostOrderVisit(expr, [&vset](const NodeRef &node) { + if (node.as()) { + vset->insert(Downcast(node)); + } + }); } private: @@ -96,6 +98,8 @@ class Poly { // and we need to ensure that they are deconstructed before the isl_ctx is freed. isl::ctx isl_ctx_; Stmt stmt_; + NodeRef spaces_; + bool gen_empty_tiling{false}; }; /// Interface for lower pass @@ -103,14 +107,14 @@ Array AutoPoly(const Stmt &stmt, const Map &extern_buff const Map &attrs, const bool is_specgemm, const bool is_dynamic) { Poly poly; poly.Run(stmt, extern_buffer, attrs, is_specgemm, false, is_dynamic); - return Array({poly.getstmt(), poly.getTilingParams()}); + return Array({poly.GetStmt(), poly.GetTilingParams()}); } NodeRef GenTuningSpace(const Stmt &stmt, const Map &extern_buffer, const Map &attrs, const bool is_specgemm) { Poly poly; poly.Run(stmt, extern_buffer, attrs, is_specgemm, true, false); - return poly.getspaces(); + return poly.GetSpaces(); } } // namespace ir } // namespace akg diff --git a/src/poly/poly_util.cc b/src/poly/poly_util.cc index 6753159d01c190c107ea8ec28ccf1d39fb2d1389..7b615581318284b558a4ff15a793a87c1bfa030b 100644 --- a/src/poly/poly_util.cc +++ b/src/poly/poly_util.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + #include "poly/poly_util.h" namespace akg { @@ -120,6 +121,65 @@ Stmt PeelOuterLetStmt(const Stmt &s, std::vector &outer_stmts) { return body; } +void GetAffOffsetAndNumVars(const isl::aff &aff, int &offset, int &num_vars) { + offset = aff.get_constant_val().get_num_si(); + + num_vars = 0; + int dim = isl_aff_dim(aff.get(), isl_dim_in); + CHECK_GE(dim, 0); + for (int j = 0; j < dim; ++j) { + isl_val *coef = isl_aff_get_coefficient_val(aff.get(), isl_dim_in, j); + int coef_val = isl_val_get_num_si(coef); + static_cast(isl_val_free(coef)); + if (coef_val != 0) ++num_vars; + } +} + +/* + * Check the isl::aff is in the form of { [i0, i1, i2, i3, i4] -> [(-64 + i2)] } + * i.e. the mapping is one variable plus a non-zero constant offset. + */ +bool IsAffVarPlusOffset(const isl::aff &aff) { + int offset = 0, num_vars = 0; + GetAffOffsetAndNumVars(aff, offset, num_vars); + return offset != 0 && num_vars == 1; +} + +/* + * Check the isl::aff is in the form of { [i0, i1, i2, i3, i4] -> [(64)] } + * i.e. the mapping is a non-zero constant. + */ +bool IsAffNonZeroConst(const isl::aff &aff) { + int offset = 0, num_vars = 0; + GetAffOffsetAndNumVars(aff, offset, num_vars); + return offset != 0 && num_vars == 0; +} + +isl::union_map LocalScheduleImpl(const isl::schedule_node &node, bool use_node) { + int tree_depth = node.get_tree_depth(); + int new_tree_depth = tree_depth; + if (use_node) ++new_tree_depth; + isl::schedule_node tmp_node; + isl::union_map schedule = isl::union_map::from_domain(node.get_domain()); + for (int i = 0; i < new_tree_depth; ++i) { + tmp_node = node.ancestor(tree_depth - i); + if (auto band_node = tmp_node.as()) { + if (band_node.n_member() > 0) { + schedule = schedule.flat_range_product(band_node.get_partial_schedule_union_map()); + } + } else if (auto filter_node = tmp_node.as()) { + schedule = schedule.intersect_domain(filter_node.get_filter()); + } else if (auto extension_node = tmp_node.as()) { + schedule = schedule.unite(extension_node.get_extension().reverse().intersect_range(schedule.range())); + } + } + return schedule; +} + +isl::union_map ShortSchedule(const isl::schedule_node &node) { return LocalScheduleImpl(node, false); } + +isl::union_map LocalSchedule(const isl::schedule_node &node) { return LocalScheduleImpl(node, true); } + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/poly_util.h b/src/poly/poly_util.h index c098c30261d768492509dfb7f1f0dd62dd02f8fb..c5135a1be7aab015f4f5b9a9cea2d08ee4418336 100644 --- a/src/poly/poly_util.h +++ b/src/poly/poly_util.h @@ -15,12 +15,9 @@ */ #ifndef POLY_UTIL_H_ #define POLY_UTIL_H_ -#pragma once #include -#include #include -#include -#include +#include #include "isl.h" namespace akg { @@ -31,28 +28,26 @@ namespace poly { #define PRETTY_PRINT_IR true #define DUMP_SCOP_DATA true #define DUMP_SCOP_DATA_PER_PASS false -#define DUMP_TRANSFORM true -#define DUMP_TRANSFORM_PER_PASS false #define DUMP_IN_CURRENT_DIR false -#define PRINT_C false #define PRINT_SCHEDULE_INFO false +#define PRINT_ISL_EMMITER false +#define PRINT_CCE_ISL_EMMITER false +#define PRINT_EMMITER (PRINT_ISL_EMMITER || PRINT_CCE_ISL_EMMITER) #define SPEC_GEMM true #define DELETE_FRACTAL true /// conv_backward options #define SELECT_DOMAIN_OPT true -/// transform options -#define USE_CACHED_SCHEDULE false -#define ENABLE_REPLACE_SCHEDULE_HOOK true - -/// constants -constexpr auto kReadSuffix = "read"; -constexpr auto kWriteSuffix = "write"; -constexpr auto kIterNamePrefix = "cc"; -constexpr auto kGemmIterNamePrefix = "ee"; 
-constexpr auto TENSORLISTTAILNAME = "TensorListTail"; +// timer records +#define TIMER_START timer_start = std::chrono::high_resolution_clock::now() +#define TIMER_DURATION \ + (std::chrono::duration_cast>(std::chrono::high_resolution_clock::now() - timer_start) \ + .count()) * \ + 1000 +#define TIMER_SHOW(NAME, SPEC_GEMM) \ + { LOG(INFO) << "[ Polyhedral exec time" << SPEC_GEMM << " ], " << NAME << " spent " << TIMER_DURATION << " ms"; } unsigned int WrappedStrtol(const std::string &str); @@ -68,6 +63,12 @@ Expr RemoveCast(Expr e); Stmt PeelOuterLetStmt(const Stmt &s, std::vector &outer_stmts); +isl::union_map ShortSchedule(const isl::schedule_node &node); +isl::union_map LocalSchedule(const isl::schedule_node &node); +void GetAffOffsetAndNumVars(const isl::aff &aff, int &offset, int &num_vars); +bool IsAffVarPlusOffset(const isl::aff &aff); +bool IsAffNonZeroConst(const isl::aff &aff); + class ConsolidateExprMutator : public IRMutator { public: explicit ConsolidateExprMutator(const std::unordered_map ¶ms_) : params(params_) {} @@ -86,15 +87,15 @@ class ConsolidateExprMutator : public IRMutator { } // list operators that may appear in dynamic shape params - Expr Mutate_(const Add *op, const Expr &e) { return GenericMutate(op, e); } - Expr Mutate_(const Sub *op, const Expr &e) { return GenericMutate(op, e); } - Expr Mutate_(const Mul *op, const Expr &e) { return GenericMutate(op, e); } - Expr Mutate_(const FloorDiv *op, const Expr &e) { return GenericMutate(op, e); } - Expr Mutate_(const FloorMod *op, const Expr &e) { return GenericMutate(op, e); } - Expr Mutate_(const Div *op, const Expr &e) { return GenericMutate(op, e); } - Expr Mutate_(const Mod *op, const Expr &e) { return GenericMutate(op, e); } - Expr Mutate_(const Min *op, const Expr &e) { return GenericMutate(op, e); } - Expr Mutate_(const Max *op, const Expr &e) { return GenericMutate(op, e); } + Expr Mutate_(const Add *op, const Expr &e) override { return GenericMutate(op, e); } + Expr Mutate_(const Sub *op, const Expr &e) override { return GenericMutate(op, e); } + Expr Mutate_(const Mul *op, const Expr &e) override { return GenericMutate(op, e); } + Expr Mutate_(const FloorDiv *op, const Expr &e) override { return GenericMutate(op, e); } + Expr Mutate_(const FloorMod *op, const Expr &e) override { return GenericMutate(op, e); } + Expr Mutate_(const Div *op, const Expr &e) override { return GenericMutate(op, e); } + Expr Mutate_(const Mod *op, const Expr &e) override { return GenericMutate(op, e); } + Expr Mutate_(const Min *op, const Expr &e) override { return GenericMutate(op, e); } + Expr Mutate_(const Max *op, const Expr &e) override { return GenericMutate(op, e); } const std::unordered_map ¶ms; }; @@ -168,6 +169,9 @@ constexpr auto ATTR_GEMM_WEIGHT_TRANSPOSE_BLOCK_INNER = "pragma_weight_transpose constexpr auto ATTR_ATOMIC_ADD = "atomic_add"; constexpr auto ATOMIC_COND_CLEAN = "atomic_cond_clean"; + +constexpr auto UBL0 = "UBL0"; +constexpr auto REALIZE_ = "realize_"; /****************************************************** * Following const is the mark tags for schedule tree ******************************************************/ diff --git a/src/poly/schedule_pass.cc b/src/poly/schedule_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..337f1f5f610316aa6401e950e96122e9f4c25021 --- /dev/null +++ b/src/poly/schedule_pass.cc @@ -0,0 +1,247 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "schedule_pass.h" + +#include +#include + +namespace akg { +namespace ir { +namespace poly { +/* Reorder filters of a sequence/set node. + * node: must be a sequence or set node. + * old_to_new_map: map from original child position to new child position. + * The caller should make sure that there are no duplicate values. + */ +isl::schedule_node ReorderFilters(const isl::schedule_node &node, + const std::unordered_map &old_to_new_map) { + auto n_children = node.n_children(); + isl_schedule_tree *old_tree = isl_schedule_node_get_tree(node.get()); + CHECK(old_tree != nullptr); + isl_schedule_tree *new_tree = isl_schedule_node_get_tree(node.get()); + CHECK(new_tree != nullptr); + for (auto &it : old_to_new_map) { + auto old_pos = it.first; + auto new_pos = it.second; + CHECK(old_pos < n_children); + CHECK(new_pos < n_children); + isl_schedule_tree *old_child = isl_schedule_tree_get_child(old_tree, old_pos); + CHECK(old_child != nullptr); + new_tree = isl_schedule_tree_replace_child(new_tree, new_pos, old_child); + CHECK(new_tree != nullptr); + } + static_cast(isl_schedule_tree_free(old_tree)); + isl_schedule_node *new_node = isl_schedule_node_graft_tree(node.copy(), new_tree); + CHECK(new_node != nullptr); + return isl::manage(new_node); +} + +isl::union_map DependenceAnalysis(const isl::union_map &sources, const isl::union_map &targets, + const isl::union_map &kills, const isl::union_map &sch) { + auto access_info = isl::union_access_info(targets); + access_info = access_info.set_kill(kills); + access_info = access_info.set_may_source(sources); + access_info = access_info.set_schedule_map(sch); + auto union_flow = access_info.compute_flow(); + return union_flow.get_may_dependence(); +} + +isl::union_map ComputeAllDependences(const isl::schedule &schedule, const isl::union_map &reads_um, + const isl::union_map &writes_um) { + auto reads = reads_um.domain_factor_domain(); + auto writes = writes_um.domain_factor_domain(); + auto sch = schedule.get_map(); + + // RAW + auto flowDeps = DependenceAnalysis(writes, reads, writes, sch); + + // WAR and WAW + auto falseDeps = DependenceAnalysis(writes.unite(reads), writes, writes, sch); + + return flowDeps.unite(falseDeps).coalesce(); +} + +isl::schedule_node GetOuterBand(const isl::schedule_node &root) { + auto outer_band = root; + + while (!outer_band.isa()) { + auto n = outer_band.n_children(); + if (n == 1) { + outer_band = outer_band.child(0); + continue; + } else { + /* + * return the node when encountered branching or a leaf + * an empty band would be inserted elsewhere + */ + return outer_band; + } + } + + return outer_band; +} + +bool IsSequenceOrSet(const isl::schedule_node &node) { + if (node.isa()) return true; + return node.isa(); +} + +isl::union_map ComputeFilterCopyin(const isl::schedule_node &node, const isl::union_map &ori_reads, + const isl::union_map &ori_writes, const isl::schedule ori_schedule) { + CHECK(node.isa()) << "The input should be a filter node!" 
<< std::endl; + + auto filter = node.as().get_filter(); + auto reads = ori_reads.domain_factor_domain().intersect_domain(filter); + auto writes = ori_writes.domain_factor_domain().intersect_domain(filter); + auto uai = isl::union_access_info(reads); + uai = uai.set_kill(writes); + uai = uai.set_may_source(writes); + uai = uai.set_schedule(ori_schedule); + auto flow = uai.compute_flow(); + auto mayNoSource = flow.get_may_no_source(); + auto copyin = ori_reads.intersect_range(mayNoSource.range()); + + return copyin; +} + +isl::union_map ComputeFakeCopyin(const isl::schedule &schedule, const isl::union_map &fake_copyin, + const isl::union_map &ori_reads, const isl::union_map &ori_writes) { + auto root = schedule.get_root(); + auto node = GetOuterBand(root); + auto result = fake_copyin; + + if (!IsSequenceOrSet(node)) return result; + + auto n = node.n_children(); + for (auto i = 0u; i < n; ++i) { + auto child = node.child(i); + auto copyin = ComputeFilterCopyin(child, ori_reads, ori_writes, schedule); + result = result.unite(copyin); + } + + return result; +} + +isl::schedule_constraints MakeScheduleConstraints(const isl::schedule &schedule, PassInfo &pass_info) { + isl::schedule_constraints constraints; + if (pass_info.coincident_) { + constraints = isl::schedule_constraints::on_domain(schedule.get_domain()) + .set_coincidence(pass_info.dependences_) // keep it, check for more cases + .set_validity(pass_info.dependences_) + .set_proximity(pass_info.dependences_); + } else { + constraints = isl::schedule_constraints::on_domain(schedule.get_domain()) + .set_validity(pass_info.dependences_) + .set_proximity(pass_info.dependences_); + } + return constraints; +} + +/* + * Merge multiple lines of strings into a single-line string + */ +static std::string UndoPrettyPrintSchTree(const std::string &schedule) { + const char *src = schedule.c_str(); + std::stringstream dst; + bool in_string = false; + while (*src != '\0') { + if (*src == '"') { + in_string = !in_string; + if (!in_string) { + // end of string, find next non-empty char + const char *next = src + 1; + while (*next != '\0') { + char c = *next; + if (c != ' ' && c != '\t' && c != '\n' && c != '\r') { + break; + } + ++next; + } + if (*next == '"') { + // multiple consecutive strings, merge them and insert a white space + dst << " "; + src = next + 1; + in_string = true; + continue; + } + } + } + dst << *src++; + } + return dst.str(); +} + +bool LoadScheduleTreeFromFile(const std::string &filename, isl::schedule &schedule) { + std::ifstream new_schedule_file_stream(filename); + std::string schedule_to_replace_str((std::istreambuf_iterator(new_schedule_file_stream)), + std::istreambuf_iterator()); + schedule_to_replace_str = UndoPrettyPrintSchTree(schedule_to_replace_str); + isl_schedule *ss = isl_schedule_read_from_str(schedule.ctx().get(), schedule_to_replace_str.c_str()); + if (ss != nullptr) { + schedule = isl::manage(ss); + return true; + } else { + LOG(WARNING) << "Failed to load file " << filename << " to schedule tree, please check syntax of the new schedule."; + return false; + } +} + +/* + * Compare and replace schedule hook: + * Enable users to replace a specific schedule for debugging purpose. + * If the current schedule is identical to the schedule in OLD_SCHEDULE_FILE, + * the schedule will be replaced with NEW_SCHEDULE_FILE. 
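 * A typical debugging flow (illustrative only; <dump_dir> stands for whatever directory
 * ScopInfo::AddDumpDir resolves to for the current kernel):
 *   1. save the schedule to be intercepted as <dump_dir>/old_schedule.txt
 *   2. put the hand-edited schedule tree into <dump_dir>/new_schedule.txt
 *   3. rerun the pipeline; when the current schedule matches old_schedule.txt,
 *      it is replaced by the tree loaded with LoadScheduleTreeFromFile from new_schedule.txt.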
+ */ +bool ReplaceScheduleTree(isl::schedule &schedule, ScopInfo &info) { + const std::string OLD_SCHEDULE_FILE = info.AddDumpDir("old_schedule.txt"); + const std::string NEW_SCHEDULE_FILE = info.AddDumpDir("new_schedule.txt"); + // check if two files exist + char pathBuffOld[PATH_MAX + 1] = {0}; + char pathBuffNew[PATH_MAX + 1] = {0}; + bool should_compare_and_replace = false; + if (realpath(OLD_SCHEDULE_FILE.c_str(), pathBuffOld) && realpath(NEW_SCHEDULE_FILE.c_str(), pathBuffNew)) { + FILE *schedule_to_compare = fopen(pathBuffOld, "r"); + FILE *schedule_to_replace = fopen(pathBuffNew, "r"); + should_compare_and_replace = (schedule_to_compare != nullptr && schedule_to_replace != nullptr); + if (schedule_to_compare != nullptr) { + int status = fclose(schedule_to_compare); + if (status != 0) LOG(WARNING) << "Failed to close old_schedule.txt"; + } + if (schedule_to_replace != nullptr) { + int status = fclose(schedule_to_replace); + if (status != 0) LOG(WARNING) << "Failed to close new_schedule.txt"; + } + } + + if (should_compare_and_replace) { + std::ifstream old_schedule_file_stream(OLD_SCHEDULE_FILE); + std::string schedule_to_compare_str((std::istreambuf_iterator(old_schedule_file_stream)), + std::istreambuf_iterator()); + if (CompareSchTreeWithString(schedule_to_compare_str, schedule)) { + LOG(INFO) << "Current schedule is same as " << OLD_SCHEDULE_FILE << ", replace it with new schedule " + << NEW_SCHEDULE_FILE; + CHECK(LoadScheduleTreeFromFile(NEW_SCHEDULE_FILE, schedule)); + return true; + } else { + LOG(INFO) << "Current schedule is different from " << OLD_SCHEDULE_FILE << ", not replacing."; + } + } + return false; +} +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass.h b/src/poly/schedule_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..07ba3b31cb00a53655485f3774cba88cfe2e5edc --- /dev/null +++ b/src/poly/schedule_pass.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef POLY_PASS_H_ +#define POLY_PASS_H_ + +#include "poly/isl.h" +#include "poly/scop_info.h" +#include "poly/pass_info.h" + +namespace akg { +namespace ir { +namespace poly { +class SchedulePass { + public: + virtual ~SchedulePass() {} + virtual isl::schedule Run(isl::schedule sch) = 0; + + std::string GetPassName() { return pass_name_; } + std::string pass_name_; + bool restart_{false}; // triggers restart during runtime +}; + +isl::schedule_node ReorderFilters(const isl::schedule_node &node, + const std::unordered_map &old_to_new_map); +isl::union_map DependenceAnalysis(const isl::union_map &sources, const isl::union_map &targets, + const isl::union_map &kills, const isl::union_map &sch); +isl::union_map ComputeAllDependences(const isl::schedule &schedule, const isl::union_map &reads_um, + const isl::union_map &writes_um); +isl::schedule_node GetOuterBand(const isl::schedule_node &root); +bool IsSequenceOrSet(const isl::schedule_node &node); + +/* + * Compute copyin for each filter node, by intersecting the domains of + * reads and writes of the entire scop. + */ +isl::union_map ComputeFilterCopyin(const isl::schedule_node &node, const isl::union_map &ori_reads, + const isl::union_map &ori_writes, const isl::schedule ori_schedule); + +bool CompareSchTreeWithString(const std::string &compare_sch, const isl::schedule &sch); + +isl::schedule_constraints MakeScheduleConstraints(const isl::schedule &schedule, PassInfo &pass_info); + +isl::union_map RemoveReduceOpSelfDependence(ScopInfo &scop_info, PassInfo &pass_info); + +isl::union_map RemoveSelfDependence(PassInfo &pass_info); + +isl::union_map RemoveInvariantDependence(const isl::schedule &schedule, PassInfo &pass_info); + +/* + * Compute copyin for each filter and return the union of such copyins. + * In particular, return an empty result when the outermost band node + * is not a sequence/set node. + * + * "result" is the union of "copyin" from each filter node, which in + * turn is computed by ComputeFilterCopyin. + */ +isl::union_map ComputeFakeCopyin(const isl::schedule &schedule, const isl::union_map &fake_copyin, + const isl::union_map &ori_reads, const isl::union_map &ori_writes); + +} // namespace poly +} // namespace ir +} // namespace akg +#endif // POLY_PASS_H_ diff --git a/src/poly/schedule_pass/change_marknode_position.cc b/src/poly/schedule_pass/change_marknode_position.cc new file mode 100644 index 0000000000000000000000000000000000000000..7af58110cdf952c6c4f9e108a937b451cd81b2d9 --- /dev/null +++ b/src/poly/schedule_pass/change_marknode_position.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "change_marknode_position.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule ChangeMarkNodePosition::Run(isl::schedule curr_schedule) { + std::unordered_set ids = with_stmts_ids_; + if (ids.empty()) { + return curr_schedule; + } + + auto fn = [&ids](isl::schedule_node node) -> isl::schedule_node { + if (node.isa()) { + std::string mark_id = node.as().get_id().get_name(); + if (mark_id == "realize_UB" && node.child(0).isa()) { + if (node.child(0).child(0).isa()) { + node = node.get_child(0).get_child(0); // sequence + bool delete_outer_mark = true; + int n = node.n_children(); + for (int i = 0; i < n; i++) { + isl::schedule_node_filter filter_node = node.child(i).as(); + bool is_not_with_stmt = filter_node.get_filter().every_set( + [&ids](const isl::set &s) -> bool { return (ids.count(s.get_tuple_name()) == 0); }); + if (is_not_with_stmt) { + delete_outer_mark = false; + } else { + node = node.child(i).child(0); + node = node.insert_mark(isl::id(node.ctx(), mark_id)); + node = node.parent().parent(); + } + } + node = node.parent().parent(); + if (delete_outer_mark) { + node = node.del(); + } + } + } + } + return node; + }; + + return curr_schedule.get_root().map_descendant_bottom_up(fn).get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/change_marknode_position.h b/src/poly/schedule_pass/change_marknode_position.h new file mode 100644 index 0000000000000000000000000000000000000000..326c9c24d12671f65ac1d533ad7f6fe20a30f437 --- /dev/null +++ b/src/poly/schedule_pass/change_marknode_position.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_CHANGE_MARKNODE_POSITION_H_ +#define POLY_CHANGE_MARKNODE_POSITION_H_ + +#include "poly/schedule_pass.h" +#include + +namespace akg { +namespace ir { +namespace poly { + +/* + * "with" stmt aims to work around the irregular problem. + * By default, the "realize_UB" mark is on the outer band. However, for tensor-of-tensor, + * the intermediate tensor may be too large if realized in the outermost scope. + * To narrow down the scope, we move "realize_UB" mark to the filter node. + * If all filter nodes of the band are "with" stmts, we remove the outer "realize_UB" mark. 
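 * A sketch of the rewrite on the schedule tree (the statement names S_with and S_other are
 * assumptions used only for illustration):
 *   before: mark("realize_UB") -> band -> sequence -> { filter(S_with), filter(S_other) }
 *   after : mark("realize_UB") -> band -> sequence -> { filter(S_with) -> mark("realize_UB"), filter(S_other) }
 * The outer "realize_UB" mark is deleted only when every filter of the sequence holds "with" statements.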
+ */ +class ChangeMarkNodePosition : public SchedulePass { + public: + ChangeMarkNodePosition(const std::unordered_set &with_stmts_ids) : with_stmts_ids_(with_stmts_ids) { + pass_name_ = __FUNCTION__; + }; + ~ChangeMarkNodePosition(){}; + + virtual isl::schedule Run(isl::schedule sch); + + private: + std::unordered_set with_stmts_ids_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_CHANGE_MARKNODE_POSITION_H_ diff --git a/src/poly/schedule_pass/compute_inner_band_dependency.cc b/src/poly/schedule_pass/compute_inner_band_dependency.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae080f279af8b3c9975178e07acd0261c2cdc1ff --- /dev/null +++ b/src/poly/schedule_pass/compute_inner_band_dependency.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "compute_inner_band_dependency.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule ComputeInnerBandDependency::Run(isl::schedule sch) { + auto ori_reads = scop_info_.analysis_result_.GetReads(); + auto ori_writes = scop_info_.analysis_result_.GetWrites(); + auto ori_fake_copyin = scop_info_.analysis_result_.GetFakeCopyin(); + auto inner_band_dependency = + ComputeFakeCopyin(sch, ori_fake_copyin, ori_reads, ori_writes).subtract(scop_info_.analysis_result_.GetCopyin()); + scop_info_.analysis_result_.RecordInnerBandDependency(inner_band_dependency); + return sch; +} +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/compute_inner_band_dependency.h b/src/poly/schedule_pass/compute_inner_band_dependency.h new file mode 100644 index 0000000000000000000000000000000000000000..8688a00adcaee571ef98bc43e3752aaa041e8f2a --- /dev/null +++ b/src/poly/schedule_pass/compute_inner_band_dependency.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_ +#define POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_ + +#include "poly/schedule_pass.h" +#include "poly/scop_info.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * This class initialises the inner band dependency information used in InjectMulticoreToSchedule pass + * and record it in the scop info. No actual schedule tree transfrom is performed in this pass. 
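 * A minimal usage sketch (the driver code below is an assumption; only the pass itself is part of this patch):
 *   ComputeInnerBandDependency pass(scop_info);
 *   sch = pass.Run(sch);  // the schedule is returned unchanged; the inner band
 *                         // dependences are now recorded in scop_info.analysis_result_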
+ */ +class ComputeInnerBandDependency : public SchedulePass { + public: + ComputeInnerBandDependency(ScopInfo &scop_info) : scop_info_(scop_info) { pass_name_ = __FUNCTION__; } + ~ComputeInnerBandDependency() {} + + virtual isl::schedule Run(isl::schedule sch); + + private: + ScopInfo &scop_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_ diff --git a/src/poly/schedule_pass/compute_schedule.cc b/src/poly/schedule_pass/compute_schedule.cc new file mode 100644 index 0000000000000000000000000000000000000000..3427f2340d442a840ed83381e2c12f242de14627 --- /dev/null +++ b/src/poly/schedule_pass/compute_schedule.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "compute_schedule.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::union_map ComputeSchedule::ModDependences(const isl::union_map &dependences) { + isl::union_map umap = isl::union_map::empty(dependences.ctx()); + dependences.foreach_map([&](const isl::map &m) -> void { + isl::map mm = m; + if (mm.get_tuple_id(isl_dim_in) != mm.get_tuple_id(isl_dim_out)) { + isl_map *pmap = mm.copy(); + int n_in = isl_map_dim(pmap, isl_dim_in); + for (int i = 0; i < n_in; ++i) { + pmap = isl_map_plain_update_val_if_fixed(pmap, isl_dim_in, i); + } + mm = isl::manage(pmap); + } + umap = umap.unite(isl::union_map(mm)); + }); + return umap; +} + +void ComputeSchedule::SetIslOptions() { + auto ctx = pass_info_.constraints_.ctx().get(); + int status = isl_options_set_schedule_unit_max_var_coefficient_sum(ctx, 1); + CHECK(status == isl_stat_ok); + + if (scop_info_.user_config_.GetComputeReschedule()) { + status = isl_options_set_schedule_whole_component(ctx, 0); + CHECK(status == isl_stat_ok); + } else { + status = isl_options_set_schedule_maximize_coincidence(ctx, 0); + CHECK(status == isl_stat_ok); + status = isl_options_set_schedule_whole_component(ctx, 1); + CHECK(status == isl_stat_ok); + } + + if (scop_info_.user_config_.GetDisableScheduleShift()) { + status = isl_options_set_schedule_max_constant_term(ctx, 0); + CHECK(status == isl_stat_ok); + status = isl_options_set_schedule_nonneg_var_coefficient(ctx, 1); + CHECK(status == isl_stat_ok); + } + + if (scop_info_.user_config_.GetEnableScheduleMaxConstant()) { + status = isl_options_set_schedule_max_constant_term(ctx, 0); + CHECK(status == isl_stat_ok); + } + + if (scop_info_.user_config_.GetDisableLoopReversal()) { + status = isl_options_set_schedule_nonneg_var_coefficient(ctx, 1); + CHECK(status == isl_stat_ok); + } + + if (scop_info_.user_config_.GetDisableLoopFusion()) { + status = isl_options_set_schedule_serialize_sccs(ctx, 1); + CHECK(status == isl_stat_ok); + } +} + +isl::schedule ComputeSchedule::Run(isl::schedule sch) { + if (scop_info_.user_config_.GetModScheduleShift()) { + pass_info_.dependences_ = ModDependences(pass_info_.dependences_); + } + pass_info_.constraints_ = MakeScheduleConstraints(sch, pass_info_); + 
SetIslOptions();
+  return pass_info_.constraints_.compute_schedule();
+}
+
+} // namespace poly
+} // namespace ir
+} // namespace akg
diff --git a/src/poly/schedule_pass/compute_schedule.h b/src/poly/schedule_pass/compute_schedule.h
new file mode 100644
index 0000000000000000000000000000000000000000..a52f07ba774bbabe92b83b30d5405f4edbf72c45
--- /dev/null
+++ b/src/poly/schedule_pass/compute_schedule.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef POLY_COMPUTE_SCHEDULE_H_
+#define POLY_COMPUTE_SCHEDULE_H_
+
+#include "poly/schedule_pass.h"
+
+namespace akg {
+namespace ir {
+namespace poly {
+
+/*
+ * Compute schedule pass; the main tasks are as follows:
+ * 1. modify the dependences according to the configuration switches
+ * 2. generate the scheduling constraints
+ * 3. call the isl scheduler on these constraints to compute a new schedule
+ */
+class ComputeSchedule : public SchedulePass {
+ public:
+  ComputeSchedule(PassInfo &pass_info, ScopInfo &scop_info) : pass_info_(pass_info), scop_info_(scop_info) {
+    pass_name_ = __FUNCTION__;
+  }
+  ~ComputeSchedule() {}
+
+  virtual isl::schedule Run(isl::schedule sch);
+
+  void SetIslOptions();
+
+  isl::union_map ModDependences(const isl::union_map &dependences);
+
+ private:
+  PassInfo &pass_info_;
+
+  ScopInfo &scop_info_;
+};
+} // namespace poly
+} // namespace ir
+} // namespace akg
+
+#endif // POLY_COMPUTE_SCHEDULE_H_
diff --git a/src/poly/schedule_pass/compute_transfer_copyin.cc b/src/poly/schedule_pass/compute_transfer_copyin.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b51672c19bb5c3bf7a93c20587bcf708d2117870
--- /dev/null
+++ b/src/poly/schedule_pass/compute_transfer_copyin.cc
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "compute_transfer_copyin.h"
+
+namespace akg {
+namespace ir {
+namespace poly {
+
+isl::schedule ComputeTransferCopyin::Run(isl::schedule sch) {
+  // compute fake copyin
+  auto ori_reads = scop_info_.analysis_result_.GetReads();
+  auto ori_writes = scop_info_.analysis_result_.GetWrites();
+  auto ori_fake_copyin = scop_info_.analysis_result_.GetFakeCopyin();
+  isl::union_map fake_copyin = ComputeFakeCopyin(sch, ori_fake_copyin, ori_reads, ori_writes);
+  fake_copyin = fake_copyin.subtract(scop_info_.analysis_result_.GetCopyin());
+  scop_info_.analysis_result_.RecordFakeCopyin(fake_copyin);
+  isl::union_map raw_writes = ori_writes.domain_factor_domain();
+  isl::union_map raw_reads = ori_reads.domain_factor_domain();
+  isl::union_map raw_copyin = scop_info_.analysis_result_.GetCopyin().domain_factor_domain();
+  isl::union_map reads = fake_copyin.domain_factor_domain();
+  isl::union_map transfer_copyin = fake_copyin;
+  while (!reads.is_empty()) {
+    isl::union_map writes = raw_writes.intersect_range(reads.range());
+    isl::union_map dependence = DependenceAnalysis(writes, reads, writes, sch.get_map());
+    isl::union_set stmt = dependence.domain().universe();
+    scop_info_.analysis_result_.RecordTransferStmt(scop_info_.analysis_result_.GetTransferStmt().unite(stmt));
+    reads = raw_reads.intersect_domain(stmt);
+
+    // compute transfer copyin
+    isl::union_map target_acc = raw_writes.intersect_domain(stmt);
+    isl::union_map relation = target_acc.reverse().apply_range(reads);
+    transfer_copyin = transfer_copyin.apply_range(relation);
+    isl::union_map copyin = transfer_copyin.intersect_range(raw_copyin.range().universe());
+    scop_info_.analysis_result_.RecordReads(scop_info_.analysis_result_.GetReads().unite(copyin));
+    scop_info_.analysis_result_.RecordCopyin(scop_info_.analysis_result_.GetCopyin().unite(copyin));
+    transfer_copyin = transfer_copyin.subtract(copyin);
+    reads = reads.subtract(raw_copyin);
+    reads = reads.subtract(fake_copyin.domain_factor_domain());
+  }
+  return sch;
+}
+} // namespace poly
+} // namespace ir
+} // namespace akg
diff --git a/src/poly/schedule_pass/compute_transfer_copyin.h b/src/poly/schedule_pass/compute_transfer_copyin.h
new file mode 100644
index 0000000000000000000000000000000000000000..bdb364155b4df284aed2b3e8490e5bbeb250fd80
--- /dev/null
+++ b/src/poly/schedule_pass/compute_transfer_copyin.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef POLY_COMPUTE_TRANSFER_COPYIN_H_
+#define POLY_COMPUTE_TRANSFER_COPYIN_H_
+
+#include "poly/schedule_pass.h"
+
+namespace akg {
+namespace ir {
+namespace poly {
+
+/*
+ * This class initialises the transfer copyin information used in the TransferStmt and MemoryPromotion passes
+ * and records it in the scop info. No actual schedule tree transform is performed in this pass.
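+ * (Run() works iteratively: unresolved reads are mapped back to the statements that write them, those
+ * statements are recorded as transfer statements, and the reachable accesses are folded into the copy-in
+ * relation until no uncovered reads remain.)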
+ */ +class ComputeTransferCopyin : public SchedulePass { + public: + ComputeTransferCopyin(ScopInfo &scop_info, PassInfo &pass_info) : scop_info_(scop_info), pass_info_(pass_info) { + pass_name_ = __FUNCTION__; + } + ~ComputeTransferCopyin() {} + + virtual isl::schedule Run(isl::schedule sch); + + private: + ScopInfo &scop_info_; + PassInfo &pass_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_COMPUTE_TRANSFER_COPYIN_H_ diff --git a/src/poly/schedule_pass/group.cc b/src/poly/schedule_pass/group.cc new file mode 100644 index 0000000000000000000000000000000000000000..b8559128d0fb97b9233e46752aed4223e5221a88 --- /dev/null +++ b/src/poly/schedule_pass/group.cc @@ -0,0 +1,314 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "group.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "poly/dump_log.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule GroupStatements::Run(isl::schedule sch_group) { + int cluster_id = 0; + pass_info_.has_grouped_ = false; + auto fn = [&cluster_id, this](isl::schedule_node node) -> isl::schedule_node { + if (node.isa() && node.n_children() > 1 && + !node.parent().isa()) { + isl::schedule_node_sequence seq_node = node.as(); + bool should_group = true; + isl::union_set_list filter_list(node.ctx(), seq_node.n_children()); + + for (unsigned int i = 0; i < seq_node.n_children(); i++) { + isl::schedule_node child = seq_node.child(i); + if (!child.isa() || !child.child(0).isa()) { + should_group = false; + break; + } else { + isl::schedule_node_filter filnode = child.as(); + filter_list = filter_list.add(filnode.get_filter()); + } + } + if (should_group) { + pass_info_.has_grouped_ = true; + isl::id gid = isl::id(node.ctx(), std::string("group") + std::to_string(cluster_id)); + pass_info_.group_filter_map_[gid] = filter_list; + cluster_id++; + isl_schedule_node *snode = isl_schedule_node_group(node.copy(), gid.release()); + node = isl::manage(snode); + } + } + return node; + }; + sch_group = sch_group.get_root().map_descendant_bottom_up(fn).get_schedule(); + if (pass_info_.has_grouped_) { + ComputeDependenceList(); + GroupDependence(sch_group); + } + return sch_group; +} + +void GroupStatements::GroupDependence(const isl::schedule &schedule) { + isl::schedule_node rnode = schedule.get_root().child(0); + isl_union_pw_multi_aff *contract = isl_schedule_node_get_subtree_contraction(rnode.get()); + pass_info_.group_upma_ = isl::manage(contract); + isl::union_map gmap = isl::union_map::from(pass_info_.group_upma_); + pass_info_.dependences_ = pass_info_.dependences_.apply_range(gmap).apply_domain(gmap); +} + +void GroupStatements::ComputeDependenceList() { + pass_info_.dependences_.foreach_map([&](const isl::map &m) -> void { + if (m.domain().get_tuple_id() != m.range().get_tuple_id()) { + isl::space domain_space_obj = m.domain().get_space(); + isl::local_space domain_space = 
isl::manage(isl_local_space_from_space(domain_space_obj.copy())); + int dim = m.dim(isl_dim_in); + int64_t weight = 1; + for (int i = 0; i < dim; ++i) { + isl::aff get_dim_in_domain = isl::aff::var_on_domain(domain_space, isl_dim_out, i); + int max = static_cast(m.domain().max_val(get_dim_in_domain).get_num_si()); + int min = static_cast(m.domain().min_val(get_dim_in_domain).get_num_si()); + weight *= (max - min + 1); + } + Dependency dependency(m.domain().get_tuple_id(), m.range().get_tuple_id(), weight); + pass_info_.dependency_list_.push_back(dependency); + } + }); +} + +void UnGroupStatements::IsContainsCircle(const std::vector> &graph, std::vector &vis, int node, + int size) { + vis[node] = 1; + for (int i = 0; i < size; ++i) { + if (graph[node][i] != 0) { + if (vis[node] == 1) { + is_circle_ = true; + break; + } else if (vis[node] == -1) { + continue; + } else { + IsContainsCircle(graph, vis, i, size); + } + } + } + vis[node] = -1; +} + +void UnGroupStatements::DfsTopsort(std::vector> &graph, std::vector &indegree, + std::set &zeros, int cnt, int size, int64_t current_value, + int64_t current_max) { + cnt_dfs_times_++; + // constraint 1: return when dfs reaches a limit times. + if (cnt_dfs_times_ > DFS_TIMES_MAX) return; + // constraint 2: return when current max is bigger than best result. + if ((min_topsort_ != -1) && (current_max >= min_topsort_)) return; + + if (cnt == size) { + min_topsort_ = current_max; + std::vector res(temp_res_); + topsort_res_ = res; + } else { + for (auto it = zeros.begin(); it != zeros.end(); ++it) { + std::set zeros_copy(zeros); + zeros_copy.erase(*it); + temp_res_[cnt] = *it; + std::vector temp; + + for (int j = 0; j < size; ++j) { + if (graph[*it][j] == 1) { + graph[*it][j] = 0; + indegree[j]--; + if (indegree[j] == 0) { + zeros_copy.insert(j); + } + temp.emplace_back(j); + } + } + int64_t updated_value = current_value; + if (cost_map_.find(temp_res_[cnt]) != cost_map_.end()) { + updated_value += cost_map_.find(temp_res_[cnt])->second; + } + DfsTopsort(graph, indegree, zeros_copy, cnt + 1, size, updated_value, std::max(updated_value, current_max)); + for (int &itj : temp) { + graph[*it][itj] = 1; + indegree[itj]++; + } + } + } +} + +isl::union_set_list UnGroupStatements::DependenciesTopsort(const isl::union_set_list &filterlist) { + if (pass_info_.dependency_list_.empty()) return filterlist; + if (filterlist.size() == 0) return filterlist; + + // 1. 
build graph from dependency_list_ and filterlist + int graph_size = filterlist.size(); + std::unordered_map filter_map; + for (int i = 0; i < graph_size; ++i) { + isl::union_set temp = filterlist.get_at(i); + CHECK(temp.n_set() == 1u) << "number of union_set's children in filterlist should be 1"; + filter_map.insert(std::pair(temp.get_set_list().get_at(0).get_tuple_id(), i)); + } + + std::vector> graph(graph_size, std::vector(graph_size, 0)); + std::vector indegree(graph_size, 0); + for (auto &i : pass_info_.dependency_list_) { + isl::id from = i.GetStartNode(); + isl::id to = i.GetEndNode(); + if (filter_map.find(from) != filter_map.end() && filter_map.find(to) != filter_map.end()) { + int x = filter_map.find(from)->second; + int y = filter_map.find(to)->second; + // we only use similar dependency once + if (graph[x][y] == 0) { + graph[x][y] = 1; + indegree[y]++; + } + int64_t value; + if (cost_map_.find(x) == cost_map_.end()) { + value = i.GetEdgeWeight(); + } else { + value = cost_map_.find(x)->second + i.GetEdgeWeight(); + } + cost_map_.insert(std::pair(x, value)); + + if (cost_map_.find(y) == cost_map_.end()) { + value = -i.GetEdgeWeight(); + } else { + value = cost_map_.find(y)->second - i.GetEdgeWeight(); + } + cost_map_.insert(std::pair(y, value)); + } + } + // 2. judge if graph has a circle by using dfs + std::vector vis(graph_size, 0); + is_circle_ = false; + for (int i = 0; i < graph_size; ++i) { + if (vis[i] == -1) { + continue; + } + IsContainsCircle(graph, vis, i, graph_size); + if (is_circle_) return filterlist; + } + // 3. compute all the Topsort list + if (temp_res_.empty()) { + temp_res_.insert(temp_res_.begin(), graph_size, 0); + } else { + temp_res_.assign(graph_size, 0); + } + std::set zeros; + for (int i = 0; i < graph_size; ++i) { + if (indegree[i] == 0) { + zeros.insert(i); + } + } + // minTopsort == -1 means never found a result of dfs. + min_topsort_ = -1; + cnt_dfs_times_ = 0; + DfsTopsort(graph, indegree, zeros, 0, graph_size, 0, 0); + + // 4. 
output the smallest filterlist + isl::union_set_list reslist = isl::union_set_list(filterlist.ctx(), graph_size); + for (int i = 0; i < graph_size; ++i) { + isl::union_set temp = filterlist.get_at(topsort_res_[i]); + reslist = reslist.add(temp); + } + return reslist; +} + +isl::schedule_node UnGroupStatements::InsertMarknode(isl::schedule_node node, const isl::id &gid) { + if (node.isa()) { + return node.insert_mark(gid); + } else { + if (node.n_children() == 1) { + node = InsertMarknode(node.child(0), gid); + node = node.parent(); + } + return node; + } +} + +isl::schedule UnGroupStatements::Run(isl::schedule schedule) { + if (!pass_info_.has_grouped_) { + return schedule; + } + bool find_filter = false; + auto findAndMarkGroupFilter = [this, &find_filter](isl::schedule_node node) -> isl::schedule_node { + if (node.isa() && node.as().n_children() == 1) { + find_filter = true; + auto filter_node = node.as().first_child(); + isl::map_list schmap = filter_node.get_prefix_schedule_union_map().get_map_list(); + if (schmap.size() == 1) { + isl::id gid = schmap.get_at(0).domain().get_tuple_id(); + if (pass_info_.group_filter_map_.find(gid) != pass_info_.group_filter_map_.end()) { + node = InsertMarknode(node, gid); + } + } + } + if ((node.isa()) && (!find_filter)) { + find_filter = true; + isl::union_set domain = node.as().domain(); + isl::set_list setlist = domain.get_set_list(); + isl::id groupid; + if (setlist.size() == 1) { + groupid = setlist.get_at(0).get_tuple_id(); + } + if (pass_info_.group_filter_map_.find(groupid) != pass_info_.group_filter_map_.end()) { + while (node.has_children()) { + if (node.n_children() > 1) { + return node.root(); + } else { + node = node.first_child(); + } + } + node = InsertMarknode(node, groupid); + node = node.root(); + } + } + return node; + }; + schedule = schedule.get_root().map_descendant_bottom_up(findAndMarkGroupFilter).get_schedule(); + + schedule = schedule.pullback(pass_info_.group_upma_); + + auto ReplaceUngroupedFilterWithSequence = [this](isl::schedule_node node) -> isl::schedule_node { + if (node.isa()) { + isl::schedule_node_mark marknode = node.as(); + isl::id markid = marknode.get_id(); + isl::union_set_list filterlist = pass_info_.group_filter_map_[markid]; + isl::union_set_list resfilterlist = DependenciesTopsort(filterlist); + if (pass_info_.group_filter_map_.find(markid) != pass_info_.group_filter_map_.end()) { + node = node.del(); + node = node.insert_sequence(resfilterlist); + } + } + return node; + }; + schedule = schedule.get_root().map_descendant_bottom_up(ReplaceUngroupedFilterWithSequence).get_schedule(); + pass_info_.dependences_ = pass_info_.orig_dependences_; + pass_info_.constraints_ = MakeScheduleConstraints(schedule, pass_info_); + return schedule; +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/group.h b/src/poly/schedule_pass/group.h new file mode 100644 index 0000000000000000000000000000000000000000..0b6c5dce776922028378a7eefb87bd72ca22857a --- /dev/null +++ b/src/poly/schedule_pass/group.h @@ -0,0 +1,85 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_GROUP_H_ +#define POLY_GROUP_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * This class combines multiple statements into a group to accelerate the calculation process of Pluto algorithm + */ +class GroupStatements : public SchedulePass { + public: + GroupStatements(PassInfo &pass_info) : pass_info_(pass_info) { pass_name_ = __FUNCTION__; } + ~GroupStatements() {} + + virtual isl::schedule Run(isl::schedule sch); + + void GroupDependence(const isl::schedule &schedule); + + void ComputeDependenceList(); + + private: + PassInfo &pass_info_; +}; + +/* + * After compute schedule, this class will restore the group statements to the original statement sequence + */ +class UnGroupStatements : public SchedulePass { + public: + UnGroupStatements(PassInfo &pass_info) : pass_info_(pass_info) { pass_name_ = __FUNCTION__; } + ~UnGroupStatements() {} + + virtual isl::schedule Run(isl::schedule sch); + + void IsContainsCircle(const std::vector> &graph, std::vector &vis, int node, int size); + + void DfsTopsort(std::vector> &graph, std::vector &indegree, std::set &zeros, int cnt, + int size, int64_t current_value, int64_t current_max); + isl::union_set_list DependenciesTopsort(const isl::union_set_list &filterlist); + + isl::schedule_node InsertMarknode(isl::schedule_node node, const isl::id &gid); + + private: + PassInfo &pass_info_; + bool is_circle_ = false; + + // the maximum times of dfs Topsort + const int DFS_TIMES_MAX = 1000000; + + // counter times of dfs Topsort for limiting a long-time dfs process + int cnt_dfs_times_ = 0; + + // the min total cost for dfs Topsort + int64_t min_topsort_ = -1; + + std::map cost_map_; + + std::vector topsort_res_; + + std::vector temp_res_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_GROUP_H_ diff --git a/src/poly/schedule_pass/init_schedule.cc b/src/poly/schedule_pass/init_schedule.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed6b514c805774c8fc3b1aa06b5d422900155cc2 --- /dev/null +++ b/src/poly/schedule_pass/init_schedule.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "init_schedule.h" + +namespace akg { +namespace ir { +namespace poly { + +void InitSchedule::RemoveUninitializedCopyin(isl::union_map ©_in, const Binds &binds) { + isl::union_set copyin_range = copy_in.range(); + auto ForeachSetFn = [©_in, &binds](const isl::set &set) { + std::string tensor_name = set.get_tuple_name(); + bool defined = false; + for (auto bind : binds) { + if (bind.first->op->name == tensor_name) { + defined = true; + } + } + if (!defined) { + copy_in = copy_in.subtract_range(set); + LOG(WARNING) << "remove uninitialized copyin, tensor name=" << tensor_name << ", access=" << set; + } + }; + copyin_range.foreach_set(ForeachSetFn); +} + +void InitSchedule::ModDependencesBeforeGroup(const isl::schedule &schedule) { + if (!scop_info_.cube_info_.IsSpecGemm()) { + if (scop_info_.user_config_.GetRemoveSelfDependence()) { + pass_info_.dependences_ = RemoveReduceOpSelfDependence(scop_info_, pass_info_); + } + + if (scop_info_.user_config_.GetForceRemoveSelfDependence()) { + pass_info_.dependences_ = RemoveSelfDependence(pass_info_); + } + } + + if (scop_info_.user_config_.GetRemoveInvariantDependence()) { + pass_info_.dependences_ = RemoveInvariantDependence(schedule, pass_info_); + } +} + +void InitSchedule::ComputeCopyIn(const isl::schedule &schedule) { + auto reads = scop_info_.analysis_result_.GetReads().domain_factor_domain(); + auto writes = scop_info_.analysis_result_.GetWrites().domain_factor_domain(); + auto uai = isl::union_access_info(reads); + uai = uai.set_kill(writes); + uai = uai.set_may_source(writes); + uai = uai.set_schedule(schedule); + auto flow = uai.compute_flow(); + auto mayNoSource = flow.get_may_no_source(); + scop_info_.analysis_result_.RecordCopyin(scop_info_.analysis_result_.GetReads().intersect_range(mayNoSource.range())); +} + +isl::schedule InitSchedule::Run(isl::schedule sch) { + ComputeCopyIn(sch); + RemoveUninitializedCopyin(scop_info_.analysis_result_.GetCopyin(), scop_info_.user_config_.GetOriginBind()); + + pass_info_.dependences_ = ComputeAllDependences(sch, scop_info_.analysis_result_.GetReads(), + scop_info_.analysis_result_.GetWrites()); + pass_info_.orig_dependences_ = pass_info_.dependences_; + ModDependencesBeforeGroup(sch); + return sch; +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/init_schedule.h b/src/poly/schedule_pass/init_schedule.h new file mode 100644 index 0000000000000000000000000000000000000000..eec515a3645ae0dfa8088bc69078eac83d7630ab --- /dev/null +++ b/src/poly/schedule_pass/init_schedule.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_INIT_SCHEDULE_H_ +#define POLY_INIT_SCHEDULE_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Init schedule pass, the main tasks ars as follow + * 1. compute copyin + * 2. compute dependence + * 3. 
modify the dependence according to the specific scene + */ +class InitSchedule : public SchedulePass { + public: + InitSchedule(PassInfo &pass_info, ScopInfo &scop_info) : pass_info_(pass_info), scop_info_(scop_info) { + pass_name_ = __FUNCTION__; + } + ~InitSchedule() {} + + virtual isl::schedule Run(isl::schedule sch); + + void ComputeCopyIn(const isl::schedule &schedule); + void RemoveUninitializedCopyin(isl::union_map ©_in, const Binds &binds); + + void ModDependencesBeforeGroup(const isl::schedule &schedule); + + private: + PassInfo &pass_info_; + + ScopInfo &scop_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_INIT_SCHEDULE_H_ diff --git a/src/poly/schedule_pass/insert_node_for_allocc.cc b/src/poly/schedule_pass/insert_node_for_allocc.cc new file mode 100644 index 0000000000000000000000000000000000000000..9a42967392042572316ac41e8c8b8cb4a7d7ae0a --- /dev/null +++ b/src/poly/schedule_pass/insert_node_for_allocc.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "insert_node_for_allocc.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule_node InsertNodeForAllocCImpl(isl::schedule_node node) { + if (node.isa()) { + if (node.as().get_id().get_name() == REALIZE_L1) { + node = node.del(); + node = + node.as().split(static_cast(node.as().n_member()) - 1); + node = node.child(0); + node = node.insert_mark(isl::id(node.ctx(), REALIZE_L1)); + node = node.insert_mark(isl::id(node.ctx(), ALLOC_C)); + node = node.parent(); + } + } + return node; +} + +isl::schedule InsertNodeForAllocC::Run(isl::schedule sched) { + // add alloc_C + sched = sched.get_root().map_descendant_bottom_up(InsertNodeForAllocCImpl).get_schedule(); + return sched; +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/sink_axis.h b/src/poly/schedule_pass/insert_node_for_allocc.h similarity index 61% rename from src/poly/sink_axis.h rename to src/poly/schedule_pass/insert_node_for_allocc.h index bbb1bb3e74bf5463cc9f227ea283811ac9071b4d..b2ff804944e2b7a44617a501802df5392c6d7677 100644 --- a/src/poly/sink_axis.h +++ b/src/poly/schedule_pass/insert_node_for_allocc.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,24 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#ifndef POLY_SINK_AXIS_H_ -#define POLY_SINK_AXIS_H_ +#ifndef POLY_INSERT_NODE_FOR_ALLOCC_H_ +#define POLY_INSERT_NODE_FOR_ALLOCC_H_ -#pragma once -#include "poly/transform.h" - -#define MAX_STRIDE 65535 +#include "poly/schedule_pass.h" namespace akg { namespace ir { namespace poly { -bool FindC0Schedule(const isl::pw_aff_list &paList); -void ExchangeCoincident(std::vector &coincident, const isl::schedule_node &node, - const std::unordered_map lastIdxSchedule, const int &n); +class InsertNodeForAllocC : public SchedulePass { + public: + InsertNodeForAllocC() { pass_name_ = __FUNCTION__; }; + ~InsertNodeForAllocC(){}; + + virtual isl::schedule Run(isl::schedule sched); +}; } // namespace poly } // namespace ir } // namespace akg -#endif // POLY_SINK_AXIS_H_ +#endif // POLY_INSERT_NODE_FOR_ALLOCC_H_ diff --git a/src/poly/schedule_pass/keep_outer_band_order.cc b/src/poly/schedule_pass/keep_outer_band_order.cc new file mode 100644 index 0000000000000000000000000000000000000000..95973e653ea8c3d6c4f034848329282f887b5eb1 --- /dev/null +++ b/src/poly/schedule_pass/keep_outer_band_order.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "keep_outer_band_order.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule KeepOuterBandOrder::Run(isl::schedule sch) { + auto outer_band_node = GetOuterBand(sch.get_root()); + if (!outer_band_node.isa()) { + return sch; + } + auto outer_band = outer_band_node.as(); + if (!outer_band.get_permutable()) { + return sch; + } + auto mupa = outer_band.get_partial_schedule(); + auto n_member = mupa.size(); + // rank outer band members according to input order + std::multimap axes_scores; // map score to old axes + for (auto i = 0u; i < n_member; ++i) { + auto upa = mupa.get_at(i); + size_t axis_score = 0; + upa.get_pw_aff_list().foreach([&](const isl::pw_aff &pw_aff) { + pw_aff.foreach_piece([&](const isl::set &set, const isl::aff &aff) { + size_t n_dims = isl_aff_dim(aff.get(), isl_dim_in); + CHECK_GE(n_dims, 0); + // vector + size_t min_dim_in_aff = 0; + if (info_.cube_info_.HasCube()) { + // cube + min_dim_in_aff = n_dims; + } + for (auto j = 0u; j < n_dims; ++j) { + auto coef_val = isl_aff_get_coefficient_val(aff.get(), isl_dim_in, j); + if (isl_val_get_num_si(coef_val) != 0) { + min_dim_in_aff = j; + break; + } + static_cast(isl_val_free(coef_val)); + } + axis_score += min_dim_in_aff; + }); + }); + axes_scores.insert(std::make_pair(axis_score, i)); + } + + std::vector axes_map; // new axes to old axes map + for (auto it : axes_scores) { + axes_map.push_back(it.second); + } + + // construct new outer band according to the axes map + isl::union_pw_aff_list new_upal; + for (auto i = 0u; i < n_member; ++i) { + if (i == 0) { + new_upal = isl::union_pw_aff_list(mupa.get_at(axes_map[i])); + } else { + new_upal = new_upal.add(mupa.get_at(axes_map[i])); + } + } + + auto new_mupa = isl::multi_union_pw_aff(mupa.get_space(), new_upal); + + // save permutable 
and coincident of old node + bool permutable = outer_band.get_permutable(); + std::vector coincident; + for (auto i = 0u; i < n_member; ++i) { + coincident.push_back(outer_band.member_get_coincident(axes_map[i])); + } + if (!info_.cube_info_.HasCube()) { + coincident[0] = true; + } + + // delete old node + outer_band_node = outer_band_node.del(); + + // insert new node + outer_band_node = outer_band_node.insert_partial_schedule(new_mupa); + outer_band_node = outer_band_node.as().set_permutable(permutable); + for (auto i = 0u; i < n_member; ++i) { + outer_band_node = outer_band_node.as().member_set_coincident(i, coincident[i]); + } + + return outer_band_node.get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/keep_outer_band_order.h b/src/poly/schedule_pass/keep_outer_band_order.h new file mode 100644 index 0000000000000000000000000000000000000000..2c66bb3843f3dbc8f53b1cc7b75e0cb8139041ac --- /dev/null +++ b/src/poly/schedule_pass/keep_outer_band_order.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_KEEP_OUTER_BAND_ORDER_H_ +#define POLY_KEEP_OUTER_BAND_ORDER_H_ + +#include "poly/schedule_pass.h" +#include "poly/scop_info.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Reorder axes in the outer band to the same as input IR. + */ +class KeepOuterBandOrder : public SchedulePass { + public: + KeepOuterBandOrder(ScopInfo &info) : info_(info) { pass_name_ = __FUNCTION__; } + ~KeepOuterBandOrder() {} + + virtual isl::schedule Run(isl::schedule sch); + + private: + ScopInfo &info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_KEEP_OUTER_BAND_ORDER_H_ \ No newline at end of file diff --git a/src/poly/schedule_pass/label_realize_out_position.cc b/src/poly/schedule_pass/label_realize_out_position.cc new file mode 100644 index 0000000000000000000000000000000000000000..b6e683075035c48af50ba4698b2843a953a4e8b1 --- /dev/null +++ b/src/poly/schedule_pass/label_realize_out_position.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "label_realize_out_position.h" +#include + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule LabelRealizeOutPosition::Run(isl::schedule sch_label) { + auto fn_ = [](isl::schedule_node node) -> isl::schedule_node { + if (node.isa()) { + if (REALIZE_UB == node.as().get_id().get_name() && + node.child(0).isa()) { + auto band = node.child(0).as(); + + unsigned pos = UINT_MAX; + auto updatePos_ = [&pos](isl::schedule_node node) -> isl::schedule_node { + if (node.isa()) { + node = node.get_child(0); + if (node.isa()) { + auto band = node.as(); + CHECK_LT(band.n_member(), UINT_MAX); + for (unsigned i = 0; i < band.n_member(); ++i) { + if (!band.member_get_coincident(i)) { + if (i < pos) pos = i; + break; + } + } + } + node = node.parent(); + } + return node; + }; + + static_cast(band.map_descendant_bottom_up(updatePos_)); + + for (unsigned i = 0; i < band.n_member(); ++i) { + if (!band.member_get_coincident(i)) { + if (i < pos) pos = i; + break; + } + } + + if (pos < band.n_member()) { + node = node.del(); + node = node.as().split(pos); + node = node.child(0); + node = node.insert_mark(isl::id(node.ctx(), REALIZE_UB)); + node = node.insert_mark(isl::id(node.ctx(), ALLOC_REALIZE_OUT)); + node = node.parent(); + } + } + } + return node; + }; + return sch_label.get_root().map_descendant_bottom_up(fn_).get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/stmt_parse.cc b/src/poly/schedule_pass/label_realize_out_position.h similarity index 57% rename from src/poly/stmt_parse.cc rename to src/poly/schedule_pass/label_realize_out_position.h index cb311566a169b467b6bb06423d2aeac53cb07e10..0bb35ca40514718ce0b47e84975947c6d5a00bb4 100644 --- a/src/poly/stmt_parse.cc +++ b/src/poly/schedule_pass/label_realize_out_position.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,25 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "poly/stmt_parse.h" +#ifndef POLY_LABEL_REALIZE_OUT_POSITION_H_ +#define POLY_LABEL_REALIZE_OUT_POSITION_H_ -#include -#include -#include -#include -#include +#include "poly/schedule_pass.h" namespace akg { namespace ir { namespace poly { -static const char *PolyOpTypeKey[] = {FOREACH(GENERATE_STRING)}; -const char *getPolyOpTypeKey(PolyOpType type) { - int idx = static_cast(type); - const int num_type_keys = sizeof(PolyOpTypeKey) / sizeof(PolyOpTypeKey[0]); - CHECK(idx < num_type_keys) << "invalid type " << idx; - return PolyOpTypeKey[idx]; -} +class LabelRealizeOutPosition : public SchedulePass { + public: + LabelRealizeOutPosition() { pass_name_ = __FUNCTION__; }; + ~LabelRealizeOutPosition(){}; + + virtual isl::schedule Run(isl::schedule sch_label); +}; + } // namespace poly } // namespace ir } // namespace akg + +#endif // POLY_LABEL_REALIZE_OUT_POSITION_H_ diff --git a/src/poly/schedule_pass/mark_fuse_op.cc b/src/poly/schedule_pass/mark_fuse_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc1dd8d4c486e501d0aa49aaeba37ee85e4a73f1 --- /dev/null +++ b/src/poly/schedule_pass/mark_fuse_op.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mark_fuse_op.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule MarkFuseOp::Run(isl::schedule schedule_mark) { + auto fn = [](isl::schedule_node node) -> isl::schedule_node { + if (node.isa()) { + std::string mark_id = node.as().get_id().get_name(); + size_t pos = mark_id.find(UBL0); + if (pos != std::string::npos) { + std::string m = FUSE_VECTOR; + node = node.insert_mark(isl::id(node.ctx(), m)); + node = node.parent(); + } + } + return node; + }; + return schedule_mark.get_root().map_descendant_bottom_up(fn).get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/mark_fuse_op.h b/src/poly/schedule_pass/mark_fuse_op.h new file mode 100644 index 0000000000000000000000000000000000000000..bbe0586bab625202fae23932ed71488d62df77a7 --- /dev/null +++ b/src/poly/schedule_pass/mark_fuse_op.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_MARK_FUSE_OP_H_ +#define POLY_MARK_FUSE_OP_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +class MarkFuseOp : public SchedulePass { + public: + MarkFuseOp() { pass_name_ = __FUNCTION__; }; + ~MarkFuseOp(){}; + + virtual isl::schedule Run(isl::schedule schedule_mark); +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_MARK_FUSE_OP_H_ diff --git a/src/poly/schedule_pass/mark_outer_most.cc b/src/poly/schedule_pass/mark_outer_most.cc new file mode 100644 index 0000000000000000000000000000000000000000..9857d197bf632f3e482e6fd35e57492ce9dff17d --- /dev/null +++ b/src/poly/schedule_pass/mark_outer_most.cc @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "mark_outer_most.h" + +namespace akg { +namespace ir { +namespace poly { + +std::vector getIsolateVector(const isl::schedule_node_band &node) { + auto build_options = node.get_ast_build_options().get_set_list(); + std::vector isolate_vector(node.n_member(), true); + for (auto idx = 0u; idx < build_options.size(); ++idx) { + if (build_options.get_at(idx).get_tuple_name() == "isolate") { + const isl::set &isolate_set = build_options.get_at(idx); + for (int dim = 0; dim < static_cast(node.n_member()); dim++) { + isolate_vector[dim] = (isolate_set.simple_hull().dim_max_val(dim).get_num_si() > 0); + } + break; + } + } + return isolate_vector; +} + +bool InjectMulticoreToBand(isl::schedule_node &band_node) { + auto node = band_node.as(); + if (node.is_null()) return false; + if (node.n_member() < 1) return false; + if (!node.get_permutable()) return false; + + auto isolate_vector = getIsolateVector(node); + bool has_coincident = false; + std::string mark = "multicore_coincident"; + for (int dim = 0; dim < static_cast(node.n_member()); ++dim) { + bool is_dim_coincident = isolate_vector[dim] && node.member_get_coincident(dim); + has_coincident = has_coincident || is_dim_coincident; + mark += "_" + std::to_string(is_dim_coincident); + } + if (has_coincident) { + band_node = band_node.insert_mark(isl::id(band_node.ctx(), mark)); + } + return has_coincident; +} + +isl::schedule_node &ObtainSequenceOrSetNodeAncestor(isl::schedule_node &node) { + do { + node = node.parent(); + } while (!node.isa() && !node.isa()); + return node; +} + +bool InjectMulticoreToChildrenBands(isl::schedule_node &sequence_node) { + bool has_multicore = false; + for (unsigned int filter = 0; filter < sequence_node.n_children(); ++filter) { + auto filter_node = sequence_node.get_child(filter); + auto band = GetOuterBand(filter_node); + if (InjectMulticoreToBand(band)) { + has_multicore = true; + sequence_node = ObtainSequenceOrSetNodeAncestor(band); + } + } + return has_multicore; +} + +bool MarkOuterMost::SingleMulticoreBand(isl::schedule_node &outer_band) { + if (outer_band.as() || outer_band.as()) { + int multi_core_band = 0; + for (unsigned int i = 0; i < outer_band.n_children(); ++i) { + isl::schedule_node node = outer_band.get_child(i); + if (node.isa()) { + auto filter = node.as(); + if (filter.has_children()) { + auto node0 = filter.get_child(0); + if (node0.isa() && node0.as().n_member() >= 1) { + multi_core_band++; + } + } + } + } + if (multi_core_band == 1) { + return true; + } + } + return false; +} + +bool MarkOuterMost::InjectMulticoreToSchedule(isl::schedule_node &outer_band) { + if (outer_band.as()) { + return InjectMulticoreToBand(outer_band); + } else if (outer_band.as() || outer_band.as()) { + if (SingleMulticoreBand(outer_band)) { + for (unsigned int i = 0; i < outer_band.n_children(); ++i) { + isl::schedule_node node = outer_band.get_child(i); + if (node.isa()) { + auto filter = node.as(); + if (filter.has_children() && filter.get_child(0).isa() && + filter.get_child(0).as().n_member() >= 1) { + isl::schedule_node tmp = filter.get_child(0); + bool injected = InjectMulticoreToBand(tmp); + outer_band = ObtainSequenceOrSetNodeAncestor(tmp); + return injected; + } + } + } + } + bool is_bands_independent = scop_info_.analysis_result_.GetInnerBandDependency().is_empty(); + if (!is_bands_independent) { + // Conv outer bands indeed have inter-band dependency, but it will be fixed in post_fusion, + // so Conv can still use multicore. This is actually dangerous and may need double check. 
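+      // For all other operators, an inter-band dependency means multicore injection is unsafe, so give up here.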
+ if (!this->scop_info_.cube_info_.IsConv()) return false; + } + return InjectMulticoreToChildrenBands(outer_band); + } + return false; +} + +isl::schedule MarkOuterMost::Run(isl::schedule schedule_mark) { + isl::schedule_node root = schedule_mark.get_root(); + isl::schedule_node outer_band = GetOuterBand(root); + bool has_multi_core = InjectMulticoreToSchedule(outer_band); + if (has_multi_core) { + return outer_band.get_schedule(); + } else { + LOG(INFO) << "This operator is not capable of using multi-core. " + << "Possible reasons are: " + << "1) there is dependency between outer bands. " + << "2) outer bands are not tiled (only tiles of outer band can use multicore)."; + return schedule_mark; + } +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/mark_outer_most.h b/src/poly/schedule_pass/mark_outer_most.h new file mode 100644 index 0000000000000000000000000000000000000000..8eed5869b94098da4730a10873685bb93012fe71 --- /dev/null +++ b/src/poly/schedule_pass/mark_outer_most.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_MARK_OUTER_MOST_H_ +#define POLY_MARK_OUTER_MOST_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +std::vector getIsolateVector(const isl::schedule_node_band &node); +bool InjectMulticoreToBand(isl::schedule_node &band_node); +isl::schedule_node &ObtainSequenceOrSetNodeAncestor(isl::schedule_node &node); +bool InjectMulticoreToChildrenBands(isl::schedule_node &sequence_node); + +class MarkOuterMost : public SchedulePass { + public: + MarkOuterMost(ScopInfo &scop_info) : scop_info_(scop_info) { pass_name_ = __FUNCTION__; }; + ~MarkOuterMost(){}; + + virtual isl::schedule Run(isl::schedule schedule_mark); + + private: + bool InjectMulticoreToSchedule(isl::schedule_node &outer_band); + bool SingleMulticoreBand(isl::schedule_node &outer_band); + + ScopInfo &scop_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_MARK_OUTER_MOST_H_ diff --git a/src/poly/schedule_pass/memory_manager.cc b/src/poly/schedule_pass/memory_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..372f4ec5659b4691acb2baa1a9662e01e8e06d92 --- /dev/null +++ b/src/poly/schedule_pass/memory_manager.cc @@ -0,0 +1,771 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "memory_manager.h" +#include "poly/dma_inject.h" +#include "poly/scop_builder.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::union_set CollectDomain(const isl::schedule_node &node) { + int depth = node.get_tree_depth(); + isl::schedule_node tmp_node; + isl::union_set domain = node.get_domain(); + for (int i = 0; i < depth; ++i) { + tmp_node = node.ancestor(depth - i); + if (auto filter_node = tmp_node.as()) { + domain = domain.intersect(filter_node.get_filter()); + } + if (auto extension_node = tmp_node.as()) { + auto parent_schedule = ShortSchedule(tmp_node); + auto extension = extension_node.get_extension(); + parent_schedule = parent_schedule.intersect_domain(domain); + domain = domain.unite(parent_schedule.range().apply(extension)); + } + } + return domain; +} + +isl::schedule_node MapDescendantTopDown(isl::schedule_node node, + const std::function &fn) { + unsigned int depth_ = node.get_tree_depth(); + do { + do { + node = fn(node); + } while (node.has_children() && (node = node.first_child())); + + while (node.get_tree_depth() > depth_ && !node.has_next_sibling()) { + node = node.parent(); + } + + if (node.get_tree_depth() > depth_) { + node = node.next_sibling(); + } + } while (node.get_tree_depth() > depth_); + + return node; +} + +void GetVisitedStmts(const isl::schedule_node &root) { + int n = root.n_children(); + if (n <= 0) return; + + isl::schedule_node node; + if (root.isa()) { + isl::union_set visited_stmts; + for (int i = 0; i < n; ++i) { + node = root.child(i); + auto filter_node = node.as(); + CHECK(filter_node) << "expected children of sequence to be filters"; + auto filter = filter_node.get_filter().universe(); + if (visited_stmts.get()) { + CHECK(visited_stmts.intersect(filter).is_empty()) << "filters are expected to be disjoint as stmt level"; + visited_stmts = visited_stmts.unite(filter); + } else { + visited_stmts = filter; + } + } + } + + for (int i = 0; i < n; ++i) { + node = root.child(i); + GetVisitedStmts(node); + } +} + +std::vector CollectMarkNode(const isl::schedule_node &tree, const std::string &mark_tag) { + std::vector mark_nodes; + tree.foreach_descendant_top_down([&mark_nodes, &mark_tag](const isl::schedule_node &node) -> bool { + if (auto mark_node = node.as()) { + // ignore nested mark nodes + if (mark_node.get_id().get_name() == mark_tag) { + mark_nodes.push_back(node); + return false; + } + } + return true; + }); + return mark_nodes; +} + +isl::schedule MemoryManager::Run(isl::schedule sch) { + schedule_ = sch; + + AddStateTensorsDataFlow(); + ReorderBufferedDefInfos(); + + auto schedule = sch; + GetVisitedStmts(schedule.get_root()); + for (size_t index = 0; index < scop_info_.analysis_result_.buffer_def_infos_.size(); index++) { + if (scop_info_.analysis_result_.buffer_def_infos_[index].find_buffer) continue; + std::string mark_tag = scop_info_.analysis_result_.buffer_def_infos_[index].mark_tag; + if (scop_info_.analysis_result_.buffer_def_infos_[index].IsIm2col()) { + isl::id nextTensorId = scop_info_.analysis_result_.buffer_def_infos_[index].NextTensorDstId(); + mark_tag = scop_info_.analysis_result_.GetBufferDefInfo(nextTensorId).mark_tag; + } + schedule = HoistBufferFootprintAtMarkNode(schedule.get_root(), mark_tag, index); + } + CHECK_EQ(buffer_footprint_queue_.size(), 0); + if (scop_info_.user_config_.GetEnableHoistCondWrite()) { + scop_info_.CollectConditionalWritePromotions(); + } + return schedule; +} + +isl::schedule MemoryManager::HoistBufferFootprintAtMarkNode(const isl::schedule_node &root, const 
std::string &mark_tag, + size_t index) { + auto fn = [mark_tag, index, this](isl::schedule_node node) -> isl::schedule_node { + if (node.isa()) { + std::string mark_id = node.as().get_id().get_name(); + if (mark_id == mark_tag) { + node = HoistBufferFootprintAtMarkNode(node.get_child(0), index); + } + } + return node; + }; + + return MapDescendantTopDown(root, fn).get_schedule(); +} + +isl::schedule_node MemoryManager::HoistBufferFootprintAtMarkNode(const isl::schedule_node &tree, size_t index) { + auto schedule = LocalSchedule(tree); + + // hoist cluster and add extension to schedule tree + return HoistTensorClusterFootprint(tree, index, schedule); +} + +isl::schedule_node MemoryManager::HoistTensorClusterFootprint(isl::schedule_node tree, size_t buffered_fp_idx, + const isl::union_map &schedule) { + BufferDefInfo &tensor_info = scop_info_.analysis_result_.buffer_def_infos_[buffered_fp_idx]; + isl::union_map sch_map = scop_info_.analysis_result_.GetScheduleMapBeforeTile(); + + isl::schedule_node mark_node = tree; + if (tree.has_parent()) { + mark_node = tree.parent(); + } + + isl::id src_tensor_id = tensor_info.tensor_id; + isl::id dst_tensor_id = tensor_info.dst_tensor_id; + bool is_bind_tensor = tensor_info.is_bind_tensor; + + auto fp_cluster = tensor_info.GetFootPrintCluster(mark_node); + if ((fp_cluster == nullptr) || (!fp_cluster->foot_print_.box.is_valid())) { + LOG(INFO) << "FootprintsClusters: fp_cluster is null or box is invalid! src: " << src_tensor_id + << ", dst: " << dst_tensor_id; + return tree; + } + + auto active_domains = CollectDomain(tree); + auto active_buf_fp = CollectBufferedFootprints(active_domains, src_tensor_id); + auto foot_prints = isl::set::empty(fp_cluster->GetSingleAccessRange().get_space()); + auto all_read_only = fp_cluster->UnWriteable(); + for (const auto &buf_fp : active_buf_fp) { + foot_prints = foot_prints.unite(buf_fp.second.cluster->GetSingleAccessRange()); + all_read_only = all_read_only && buf_fp.second.cluster->UnWriteable(); + } + + if (is_bind_tensor && tensor_info.mem_type != MemType::UBL0_) { + if (!(scop_info_.cube_info_.IsGemm() && tensor_info.IsCubeCL1Write())) { + bool insert_ub_to_l1 = false; + if (!scop_info_.analysis_result_.GetFakeCopyin().is_empty()) { + scop_info_.analysis_result_.GetFakeCopyin().foreach_map( + [&insert_ub_to_l1, &src_tensor_id, &dst_tensor_id](const isl::map &m) -> void { + if ((m.get_tuple_id(isl_dim_out).get_name() == src_tensor_id.get_name()) && + (src_tensor_id.get_name() + "_local_L1" == dst_tensor_id.get_name())) { + insert_ub_to_l1 = true; + } + }); + } + if (insert_ub_to_l1) { + isl::id outer_tensorId = isl::id(src_tensor_id.ctx(), src_tensor_id.get_name() + "_local_UB"); + tree = PlaceInnerDataCopyBelow(scop_info_, tree, *fp_cluster, *fp_cluster, src_tensor_id, dst_tensor_id, + outer_tensorId, sch_map); + } else { + tree = PlaceOuterDataCopyBelow(scop_info_, tree, *fp_cluster, src_tensor_id, dst_tensor_id, sch_map, + schedule_.get_domain().get_space()); + } + } else { + buffer_footprint_queue_.push(src_tensor_id); + } + // If the new buffer_footprint is not a strict subset of any other parent + auto cluster = std::shared_ptr(std::move(fp_cluster)); + scop_info_.analysis_result_.active_buffer_footprints_.emplace_back( + std::make_pair(active_domains, BufferedFootPrintInfo{cluster, schedule, dst_tensor_id})); + tensor_info.find_buffer = true; + return tree; + } + + if (tensor_info.IsIm2col()) { + isl::id cluster_id = tensor_info.NextTensorDstId(); + auto l0_fp_cluster = GetFootPrintsCluster(dst_tensor_id); + 
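+    // The destination (L0) footprint cluster must already have been collected for this tensor;
+    // PlaceIm2colBelow below relies on it, so fail loudly if it is missing.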
CHECK(l0_fp_cluster != nullptr); + tree = PlaceIm2colBelow(scop_info_, tree, *l0_fp_cluster, *fp_cluster, cluster_id, dst_tensor_id); + // If the new buffer_footprint is not a strict subset of any other parent + auto cluster = std::shared_ptr(std::move(l0_fp_cluster)); + scop_info_.analysis_result_.active_buffer_footprints_.emplace_back( + std::make_pair(active_domains, BufferedFootPrintInfo{cluster, schedule, dst_tensor_id})); + tensor_info.find_buffer = true; + SetFindBuffer(dst_tensor_id, true); + return tree; + } + + if (tensor_info.IsGemmDataL12L0()) { + if (scop_info_.cube_info_.IsGemmDataTranspose()) { + const isl::id &trans_id = dst_tensor_id; + const isl::id &cluster_id = dst_tensor_id; + tree = PlaceIm2colBelow(scop_info_, tree, *gemm_a_transpose_fp_cluster_, *fp_cluster, trans_id, cluster_id); + scop_info_.analysis_result_.active_buffer_footprints_.emplace_back( + std::make_pair(active_domains, BufferedFootPrintInfo{gemm_a_transpose_fp_cluster_, schedule, cluster_id})); + } + } + + if (tensor_info.IsGemmWeightL12L0()) { + if (scop_info_.cube_info_.IsGemmWeightTranspose()) { + const isl::id &trans_id = dst_tensor_id; + const isl::id &cluster_id = dst_tensor_id; + tree = PlaceIm2colBelow(scop_info_, tree, *gemm_b_transpose_fp_cluster_, *fp_cluster, trans_id, cluster_id); + scop_info_.analysis_result_.active_buffer_footprints_.emplace_back( + std::make_pair(active_domains, BufferedFootPrintInfo{gemm_b_transpose_fp_cluster_, schedule, cluster_id})); + } + } + auto scop_cluster = fp_cluster; + if (scop_info_.cube_info_.IsGemm() && (tensor_info.IsGemmDataL12L0() || tensor_info.IsGemmWeightL12L0())) { + scop_cluster = scop_info_.analysis_result_.GetBufferDefInfo(tensor_info.tensor_id).footprints_cluster; + } + if (tensor_info.IsPreCubeTile2Write()) { + auto info = scop_info_.analysis_result_.GetBufferDefInfo(tensor_info.tensor_id); + auto new_scop_group = info.GetFootPrintCluster(mark_node); + if (new_scop_group != nullptr) { + scop_cluster = new_scop_group; + } + } + tree = PlaceInnerDataCopyBelow(scop_info_, tree, *fp_cluster, *scop_cluster, src_tensor_id, dst_tensor_id, + src_tensor_id, sch_map); + if (scop_info_.cube_info_.IsGemm() && !buffer_footprint_queue_.empty() && + buffer_footprint_queue_.front().get_name() == tensor_info.ancester_tensor_id.get_name()) { + tree = PlaceOuterDataCopyBelow(scop_info_, tree, *fp_cluster, tensor_info.ancester_tensor_id, src_tensor_id, + sch_map, schedule_.get_domain().get_space()); + buffer_footprint_queue_.pop(); + } + + // If the new buffer_footprint is not a strict subset of any other parent + auto group = std::shared_ptr(std::move(fp_cluster)); + + scop_info_.analysis_result_.active_buffer_footprints_.emplace_back( + std::make_pair(active_domains, BufferedFootPrintInfo{group, schedule, dst_tensor_id})); + tensor_info.find_buffer = true; + return tree; +} + +std::vector> MemoryManager::CollectBufferedFootprints( + const isl::union_set &active_domains, const isl::id &tensor_id) const { + std::vector> result; + + for (auto idx : CollectBufferedFootprintsIndexes(active_domains, tensor_id)) { + result.emplace_back(scop_info_.analysis_result_.active_buffer_footprints_[idx]); + } + return result; +} + +std::vector MemoryManager::CollectBufferedFootprintsIndexes(const isl::union_set &active_domains, + const isl::id &tensor_id) const { + std::vector result; + + for (size_t i = 0, e = scop_info_.analysis_result_.active_buffer_footprints_.size(); i < e; ++i) { + const auto &act_fp = scop_info_.analysis_result_.active_buffer_footprints_[i]; + if 
(act_fp.first.intersect(active_domains).is_empty()) {
+      continue;
+    }
+
+    auto cluster_id = act_fp.second.cluster_id;
+    for (const auto &def_iter : scop_info_.analysis_result_.BufferDefInfos()) {
+      if (def_iter.dst_tensor_id.get_name() == cluster_id.get_name() &&
+          def_iter.tensor_id.get_name() == tensor_id.get_name()) {
+        result.push_back(i);
+        break;
+      }
+    }
+  }
+  return result;
+}
+
+std::shared_ptr<TensorFootprintCluster> MemoryManager::GetFootPrintsCluster(const isl::id &tensor_id) {
+  for (const auto &info : scop_info_.analysis_result_.buffer_def_infos_) {
+    if (info.tensor_id.get_name() == tensor_id.get_name()) {
+      return info.footprints_cluster;
+    }
+  }
+  return nullptr;
+}
+
+// set find_buffer of the entry matching tensor_id in buffer_def_infos_
+void MemoryManager::SetFindBuffer(const isl::id &tensor_id, bool find_buffer) {
+  for (auto &info : scop_info_.analysis_result_.buffer_def_infos_) {
+    if (info.tensor_id.get_name() == tensor_id.get_name()) {
+      info.find_buffer = find_buffer;
+      return;
+    }
+  }
+  LOG(FATAL) << "hoisted tensor " << tensor_id << " has no declaration";
+}
+
+PartitionSingle::PartitionSingle(int times, int tile_start, int cut_m,
+                                 const std::map<std::string, Expr> &fractal_int_info) {
+  m_times_ = times;
+  m_cut_m_ = cut_m;
+  m_fractal_int_info_ = fractal_int_info;
+}
+
+PartitionSingle *PartitionSingle::single_ = nullptr;
+int PartitionSingle::m_times_ = 0;
+int PartitionSingle::m_cut_m_ = 0;
+std::map<std::string, Expr> PartitionSingle::m_fractal_int_info_;
+
+void MemoryManager::GatherBufferFootprintDefInfo(const isl::schedule_node &tree, BufferDefInfo &tensor_info) {
+  auto fp_cluster = tensor_info.GetFootPrintCluster(tree);
+  std::vector<size_t> sizes;
+  if (fp_cluster == nullptr) {
+    tensor_info.AddSize(tree, sizes);
+    return;
+  }
+  sizes = fp_cluster->GetFixedBoxSizes();
+
+  isl::id tensor_id = tensor_info.tensor_id;
+  isl::id cluster_id = tensor_info.dst_tensor_id;
+
+  // build a Halide Node for cluster_id
+  Array<Expr> shapes;
+  for (auto i : sizes) {
+    shapes.push_back(Expr(static_cast<int>(i)));
+  }
+
+  Type type = scop_info_.GetDtypeOf(tensor_id);
+  Tensor tensor = placeholder(shapes, type, cluster_id.get_name());
+  const Buffer buffer = decl_buffer(shapes, scop_info_.GetDtypeOf(tensor_id), cluster_id.get_name());
+  scop_info_.user_config_.SetBind(tensor, buffer);
+
+  tensor_info.sizes = sizes;
+  tensor_info.tensor = tensor;
+  tensor_info.data_type = type;
+  tensor_info.AddSize(tree, sizes);
+}
+
+void MemoryManager::CollectBufferFootprintDefInfo(BufferDefInfo &tensor_info, const isl::union_map &schedule_prom,
+                                                  const isl::schedule_node &node) {
+  tensor_info.footprints_cluster = TensorFootprintCluster::HoistBufferFootprintCluster(
+    schedule_prom, tensor_info.ancester_tensor_id, scop_info_.analysis_result_.GetReads(),
+    scop_info_.analysis_result_.GetCopyin(), scop_info_.analysis_result_.GetWrites(),
+    scop_info_.analysis_result_.GetFakeCopyin());
+  if (tensor_info.footprints_cluster != nullptr) {
+    tensor_info.footprint_cluster_map.emplace_back(std::make_pair(node, tensor_info.footprints_cluster));
+    GatherBufferFootprintDefInfo(node, tensor_info);
+  }
+}
+
+void MemoryManager::HoistIm2colBufferFootprintCluster(const isl::union_map &schedule, const isl::schedule_node &node,
+                                                      const int index, BufferDefInfo &tensor_info) {
+  im2col_fp_cluster = ConstructAffineFpCluster(scop_info_, scop_info_.analysis_result_.GetReads(), schedule.domain(),
+                                               schedule, ReferenceType::Read, AffineType::AFFINE_IM2COL);
+  tensor_info.footprints_cluster =
+    ConstructAffineFpCluster(scop_info_,
scop_info_.analysis_result_.GetReads(), schedule.domain(), schedule, + ReferenceType::Read, AffineType::AFFINE_FRACTAL); + CHECK_EQ(index, 0); + CHECK(im2col_fp_cluster != nullptr) << "im2col_fp_cluster must be not null"; + CHECK(tensor_info.footprints_cluster != nullptr) << "footprint cluster in Im2col must be defined"; + tensor_info.footprint_cluster_map.emplace_back(std::make_pair(node, tensor_info.footprints_cluster)); + + if ((tensor_info.footprints_cluster->foot_print_.box.is_valid()) && (im2col_fp_cluster->foot_print_.box.is_valid())) { + GatherBufferFootprintDefInfo(node, tensor_info); + // this update info is used for spec gemm + scop_info_.cube_info_.UpdateFractalIntFirstInfo(scop_info_.cube_info_.IsConvBackpropFilter(), + im2col_fp_cluster->GetFixedBoxSizes(), + tensor_info.footprints_cluster->GetFixedBoxSizes()); + } else { + int64_t t_ci = 1; + int64_t k_h = 0; + int64_t k_w = 0; + int64_t t_h = 1; + int64_t t_w = 1; + int64_t s_h = 1; + int64_t s_w = 1; + int64_t t_ho = 1; + int64_t t_wo = 1; + int64_t c_in = 0; + LOG(INFO) << "im2col or fractal foot_print_ box is invalid."; + + Map attr_info = scop_info_.cube_info_.GetConvAttrInfo(); + auto it = attr_info.find(ATTR_CONV_KERNEL_H); + if ((it != attr_info.end()) && (*it).second.as()) k_h = (*it).second.as()->value; + it = attr_info.find(ATTR_CONV_KERNEL_W); + if ((it != attr_info.end()) && (*it).second.as()) k_w = (*it).second.as()->value; + it = attr_info.find(ATTR_CONV_STRIDE_H); + if ((it != attr_info.end()) && (*it).second.as()) s_h = (*it).second.as()->value; + it = attr_info.find(ATTR_CONV_STRIDE_W); + if ((it != attr_info.end()) && (*it).second.as()) s_w = (*it).second.as()->value; + it = attr_info.find(ATTR_CONV_TILE_H); + if ((it != attr_info.end()) && (*it).second.as()) t_h = (*it).second.as()->value; + it = attr_info.find(ATTR_CONV_TILE_W); + if ((it != attr_info.end()) && (*it).second.as()) t_w = (*it).second.as()->value; + it = attr_info.find(ATTR_CONV_FEATURE_C); + if ((it != attr_info.end()) && (*it).second.as()) c_in = (*it).second.as()->value; + + t_ho = (t_h - k_h) / s_h + 1; + t_wo = (t_w - k_w) / s_w + 1; + + bool replace_ci = false; + auto dynamic_shape = scop_info_.user_config_.GetDynamicShape(); + if (!dynamic_shape.empty()) { + for (const auto &ds : dynamic_shape) { + if (auto dsn = ds.as()) { + if (dsn->tensor_name == "CI1") { + t_ci = (int64_t)(dsn->poly_upper_bound - 1); + replace_ci = true; + } + } + } + } + if (!replace_ci) { + t_ci = (int64_t)(c_in + 15) / 16; + } + + std::vector sizes; + sizes.push_back(1); // 1 + sizes.push_back((size_t)((t_ho * t_wo + 15) / 16)); // 109 + sizes.push_back((size_t)(t_ci * k_h * k_w)); // 43648 + sizes.push_back(16); // 16 + sizes.push_back(16); // 16 + scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_GMM_M] = t_ho * t_wo; // 1739 + scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_BATCH] = (int64_t)sizes[0]; + scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_TILE_M] = (int64_t)sizes[1]; + scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_TILE_K] = (int64_t)sizes[2]; + scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_M_INNER] = (int64_t)sizes[3]; + scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_K_INNER] = (int64_t)sizes[4]; + GatherFractalDefInfo(node, tensor_info, sizes); + } + scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_FEATURE_W] = + scop_info_.cube_info_.ExtractExprFromAttrs(ATTR_CONV_FEATURE_W); + scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_PAD_LEFT] = + scop_info_.cube_info_.ExtractExprFromAttrs(ATTR_CONV_PAD_LEFT); + 
scop_info_.cube_info_.fractal_int_info_[ATTR_CONV_PAD_RIGHT] = + scop_info_.cube_info_.ExtractExprFromAttrs(ATTR_CONV_PAD_RIGHT); +} + +void MemoryManager::MakeMultiBufferFootprint(const isl::union_map &schedule, const isl::schedule_node &node, int &index, + BufferDefInfo &tensor_info) { + if (!scop_info_.IsCopyinTensor(tensor_info.ancester_tensor_id.get_name())) { + CollectBufferFootprintDefInfo(tensor_info, schedule, node); + } else { + if (index == 0) { + CollectBufferFootprintDefInfo(tensor_info, schedule, node); + } else { + isl::id new_dst_id = tensor_info.GetIndexDstId(scop_info_.ctx_, tensor_info.dst_tensor_id, index); + BufferDefInfo new_footprint_info = BufferDefInfo{tensor_info.tensor_id, + new_dst_id, + tensor_info.ancester_tensor_id, + tensor_info.mem_type, + tensor_info.mark_tag, + false, + tensor_info.is_bind_tensor, + tensor_info.MakeDataStream(new_dst_id), + Tensor(), + Handle(), + tensor_info.sizes, + nullptr, + isl::union_map::empty(CreateParamsSpace(scop_info_.ctx_))}; + CollectBufferFootprintDefInfo(new_footprint_info, schedule, node); + scop_info_.analysis_result_.buffer_def_infos_.push_back(new_footprint_info); + } + } +} + +void MemoryManager::AddStateTensorsDataFlow() { + // build init list + // init list TensorID input0 DDR --> L1 --> L1 --> L0A + // TensorID input1 DDR --> L0B + // TensorID input2 DDR --> UB + // TensorID output0 DDR <-- UB <-- L0C + // TensorID max_1 UB --> DDR + // build whole list + // add below node + // TensorID input0_local_L1 L1 --> L1 --> L0A + // TensorID input0_fractal_L1 L1 --> L0A + // TensorID input0_fractal_L1_local_L0A L0A + // TensorID input1_local_L1_local_L0B L0B + // TensorID output0_local_UB UB <-- L0C + // TensorID output0_local_UB_local_L0C L0C + // TensorID input2_local_UB UB + // TensorID max_1_local_UB UB + auto tensor_name_flows = scop_info_.analysis_result_.GetTensorNameFlows(); + auto tensor_mem_flows = scop_info_.analysis_result_.GetTensorMemFlows(); + CHECK_EQ(tensor_mem_flows.size(), tensor_name_flows.size()); + CHECK_GT(tensor_mem_flows.size(), 0); + for (const auto &tensor : tensor_mem_flows) { + std::string name = tensor.first; + if (tensor_name_flows.find(name) == tensor_name_flows.end()) continue; + auto it = std::find(tensor_mem_flows[name].begin(), tensor_mem_flows[name].end(), UBL1_); + auto it2 = std::find(tensor_mem_flows[name].begin(), tensor_mem_flows[name].end(), L1_); + if (it != tensor_mem_flows[name].end() && it2 != tensor_mem_flows[name].end()) { + std::vector name_flow1, name_flow2; + MemFlow mem_flow1, mem_flow2; + if (scop_info_.cube_info_.IsConv() || scop_info_.cube_info_.IsGemm()) { + name_flow1.push_back(tensor_name_flows[name][0]); + mem_flow1.push_back(tensor_mem_flows[name][0]); + name_flow1.push_back(tensor_name_flows[name][2]); + mem_flow1.push_back(tensor_mem_flows[name][2]); + name_flow1.push_back(tensor_name_flows[name][1]); + mem_flow1.push_back(tensor_mem_flows[name][1]); + + name_flow2.push_back(tensor_name_flows[name][0]); + mem_flow2.push_back(tensor_mem_flows[name][0]); + name_flow2.push_back(tensor_name_flows[name][2]); + mem_flow2.push_back(tensor_mem_flows[name][2]); + name_flow2.push_back(tensor_name_flows[name][3]); + mem_flow2.push_back(tensor_mem_flows[name][3]); + } + if (scop_info_.cube_info_.IsConv() && scop_info_.cube_info_.IsA(name)) { + name_flow2.push_back(tensor_name_flows[name][4]); + mem_flow2.push_back(tensor_mem_flows[name][4]); + } + + AddTensorDataFlow(mem_flow1, name_flow1); + AddTensorDataFlow(mem_flow2, name_flow2); + + continue; + } + 
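+    // Default path: tensors whose memory flow does not contain both L1 and UBL1 keep their
+    // original name/memory flow unchanged.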
AddTensorDataFlow(tensor.second, tensor_name_flows[name]); + } + + size_t length = scop_info_.analysis_result_.buffer_def_infos_.size(); + for (size_t tensor_idx = 0; tensor_idx < length; tensor_idx++) { + if (scop_info_.analysis_result_.buffer_def_infos_[tensor_idx].data_stream.size() == 1) continue; + + isl::id ancestor_id = scop_info_.analysis_result_.buffer_def_infos_[tensor_idx].tensor_id; + for (size_t idx = 1; idx < scop_info_.analysis_result_.buffer_def_infos_[tensor_idx].data_stream.size(); ++idx) { + if (idx + 1 == scop_info_.analysis_result_.buffer_def_infos_[tensor_idx].data_stream.size()) continue; + std::vector> sub_data_stream = + scop_info_.analysis_result_.buffer_def_infos_[tensor_idx].PartialDataStream(idx); + AddOneBufferDefInfo(ancestor_id, sub_data_stream); + } + } +} + +void MemoryManager::AddOneBufferDefInfo(const isl::id &ancestor_id, + const std::vector> &data_stream) { + if (data_stream.empty()) return; + + auto target = data_stream[0]; + isl::id tensor_id = target.first; + MemType mem_type = target.second; + constexpr auto TENSORLISTTAILNAME = "TensorListTail"; + isl::id dst_tensorId = isl::id(scop_info_.ctx_, TENSORLISTTAILNAME); + MemType dst_mem_type = MemType::DDR; + if (0 < data_stream.size() - 1) { + dst_tensorId = data_stream[1].first; + dst_mem_type = data_stream[1].second; + } + + MemFlow mem_flow; + for (const auto &item : data_stream) { + mem_flow.push_back(item.second); + } + std::string mark_tag = TensorMarkTag(dst_mem_type, mem_flow); + if (mark_tag.empty()) return; + + std::vector sizes; + BufferDefInfo promoted_info = BufferDefInfo{tensor_id, + dst_tensorId, + ancestor_id, + mem_type, + mark_tag, + false, + false, + data_stream, + Tensor(), + Handle(), + sizes, + nullptr, + isl::union_map::empty(isl::space(scop_info_.ctx_, 0))}; + MakeBufferFootprintCluster(promoted_info); + scop_info_.analysis_result_.buffer_def_infos_.push_back(promoted_info); +} + +void MemoryManager::AddTensorDataFlow(const std::vector &memflow, const std::vector &nameflow) { + CHECK(memflow.size() == nameflow.size()); + uint64_t i = 0; + /********************************************* + * + * init mem_type: DDR + * init tensor_id: input0 + * init dst_tensorId: input0_local_L1 + * init ancestor_id: input0 + * + * init mark_tag: base on dst_tensorId mem_type, realize_L1 + * init data_stream: input0 --> input0_local_L1 --> input0_fractal_L1 --> input0_fractal_L1_local_L0A + **********************************************/ + std::string tensor_name = nameflow[i]; + MemType mem_type = memflow[i]; + + isl::id tensor_id = isl::id(scop_info_.ctx_, tensor_name); + isl::id ancestor_id = tensor_id; + isl::id dst_tensorId = isl::id(scop_info_.ctx_, tensor_name); + if (i < nameflow.size() - 1) { + std::string dst_tensor_name = nameflow[i + 1]; + dst_tensorId = isl::id(scop_info_.ctx_, dst_tensor_name); + } + std::vector> data_stream; + + for (size_t j = i; j < nameflow.size(); j++) { + std::string tmp_name = nameflow[j]; + isl::id tmp_id = isl::id(scop_info_.ctx_, tmp_name); + MemType tmp_mem_type = memflow[j]; + data_stream.emplace_back(std::make_pair(tmp_id, tmp_mem_type)); + } + MemType dst_mem_type = MemType::DDR; + if (data_stream.size() > 1) { + dst_mem_type = data_stream[1].second; + } + std::string mark_tag = TensorMarkTag(dst_mem_type, memflow); + if (scop_info_.cube_info_.IsIm2col() && mark_tag == REALIZE_L1) { + mark_tag = REALIZE_UB; + } + + bool isCopyin = scop_info_.IsCopyinTensor(tensor_id.get_name()); + if (!isCopyin && dst_mem_type == MemType::UBL1_) { + mark_tag = 
REALIZE_L1UBL1; + } + + std::vector sizes; + bool is_bind_tensor = true; + BufferDefInfo promoted_info = BufferDefInfo{tensor_id, + dst_tensorId, + ancestor_id, + mem_type, + mark_tag, + false, + is_bind_tensor, + data_stream, + Tensor(), + Handle(), + sizes, + nullptr, + isl::union_map::empty(isl::space(scop_info_.ctx_, 0))}; + MakeBufferFootprintCluster(promoted_info); + scop_info_.analysis_result_.buffer_def_infos_.push_back(promoted_info); +} + +void MemoryManager::MakeBufferFootprintCluster(BufferDefInfo &tensor_info) { + std::vector nodes = CollectMarkNode(schedule_.get_root(), tensor_info.mark_tag); + int index = 0; + for (const auto &node : nodes) { + isl::schedule_node tree = node.get_child(0); + auto schedule = LocalSchedule(tree); + + // get TensorFootPrintsCluster for each tensor + if (tensor_info.IsIm2col()) { + HoistIm2colBufferFootprintCluster(schedule, node, index, tensor_info); + } else { + if (tensor_info.IsGemmDataL12L0() || tensor_info.IsGemmWeightL12L0()) { + AddGemmTransposeFpCluster(schedule); + } + MakeMultiBufferFootprint(schedule, node, index, tensor_info); + scop_info_.cube_info_.UpdateSpecGemmFractalInfo(tensor_info); + } + index++; + } +} + +void MemoryManager::ReorderBufferedDefInfos() { + if (scop_info_.analysis_result_.GetFakeCopyin().is_empty()) { + return; + } + + std::unordered_set tensors; + scop_info_.analysis_result_.GetFakeCopyin().foreach_map( + [&tensors](const isl::map &m) -> void { tensors.insert(m.get_tuple_id(isl_dim_out).get_name()); }); + + for (size_t index = 1; index < scop_info_.analysis_result_.buffer_def_infos_.size(); index++) { + if ((scop_info_.analysis_result_.buffer_def_infos_[index].mark_tag == REALIZE_L1) && + (tensors.find(scop_info_.analysis_result_.buffer_def_infos_[index].tensor_id.get_name()) != tensors.end())) { + BufferDefInfo promoted_info = scop_info_.analysis_result_.buffer_def_infos_[index]; + scop_info_.analysis_result_.buffer_def_infos_.erase(scop_info_.analysis_result_.buffer_def_infos_.begin() + + static_cast(index)); + scop_info_.analysis_result_.buffer_def_infos_.insert(scop_info_.analysis_result_.buffer_def_infos_.begin(), + promoted_info); + } + } +} + +void MemoryManager::AddGemmTransposeFpCluster(const isl::union_map &schedule) { + auto domain = schedule.domain(); + if (scop_info_.cube_info_.IsGemmDataTranspose()) { + if (scop_info_.cube_info_.IsGemmDataTransposeBlock()) { + gemm_a_transpose_fp_cluster_ = + ConstructAffineFpCluster(scop_info_, scop_info_.analysis_result_.GetReads(), domain, schedule, + ReferenceType::Read, AffineType::AFFINE_GEMMBLOCK, AffineTensor::LEFT_TENSOR); + } else if (scop_info_.cube_info_.IsGemmDataTransposeInnerBlock()) { + gemm_a_transpose_fp_cluster_ = + ConstructAffineFpCluster(scop_info_, scop_info_.analysis_result_.GetReads(), domain, schedule, + ReferenceType::Read, AffineType::AFFINE_GEMMBLOCKIN, AffineTensor::LEFT_TENSOR); + } else { + gemm_a_transpose_fp_cluster_ = + ConstructAffineFpCluster(scop_info_, scop_info_.analysis_result_.GetReads(), domain, schedule, + ReferenceType::Read, AffineType::AFFINE_GEMM, AffineTensor::LEFT_TENSOR); + } + } + if (scop_info_.cube_info_.IsGemmWeightTranspose()) { + if (scop_info_.cube_info_.IsGemmWeightTransposeBlock()) { + gemm_b_transpose_fp_cluster_ = + ConstructAffineFpCluster(scop_info_, scop_info_.analysis_result_.GetReads(), domain, schedule, + ReferenceType::Read, AffineType::AFFINE_GEMMBLOCK, AffineTensor::RIGHT_TENSOR); + } else if (scop_info_.cube_info_.IsGemmWeightTransposeInnerBlock()) { + gemm_b_transpose_fp_cluster_ = + 
ConstructAffineFpCluster(scop_info_, scop_info_.analysis_result_.GetReads(), domain, schedule, + ReferenceType::Read, AffineType::AFFINE_GEMMBLOCKIN, AffineTensor::RIGHT_TENSOR); + } else { + gemm_b_transpose_fp_cluster_ = + ConstructAffineFpCluster(scop_info_, scop_info_.analysis_result_.GetReads(), domain, schedule, + ReferenceType::Read, AffineType::AFFINE_GEMM, AffineTensor::RIGHT_TENSOR); + } + } +} + +void MemoryManager::GatherFractalDefInfo(const isl::schedule_node &tree, BufferDefInfo &tensor_info, + std::vector &sizes) { + isl::id tensor_id = tensor_info.tensor_id; + isl::id cluster_id = tensor_info.dst_tensor_id; + + Array shapes; + for (auto i : sizes) { + shapes.push_back(Expr(static_cast(i))); + } + + Type type = scop_info_.GetDtypeOf(tensor_id); + Tensor tensor = placeholder(shapes, type, cluster_id.get_name()); + const Buffer buffer = decl_buffer(shapes, scop_info_.GetDtypeOf(tensor_id), cluster_id.get_name()); + scop_info_.user_config_.SetBind(tensor, buffer); + + tensor_info.sizes = sizes; + tensor_info.tensor = tensor; + tensor_info.data_type = type; + tensor_info.AddSize(tree, sizes); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/memory_manager.h b/src/poly/schedule_pass/memory_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..b09235d3ebcc07e162b36e75b4f01b770e833b76 --- /dev/null +++ b/src/poly/schedule_pass/memory_manager.h @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef POLY_MEMORY_MANAGER_H_ +#define POLY_MEMORY_MANAGER_H_ + +#include +#include "poly/pass_info.h" +#include "poly/scop_info.h" +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { +class MemoryManager : public SchedulePass { + public: + explicit MemoryManager(ScopInfo &scop_info) : scop_info_(scop_info) { pass_name_ = __FUNCTION__; } + ~MemoryManager() {} + + virtual isl::schedule Run(isl::schedule sch); + + private: + isl::schedule HoistBufferFootprintAtMarkNode(const isl::schedule_node &root, const std::string &markTag, + size_t index); + isl::schedule_node HoistBufferFootprintAtMarkNode(const isl::schedule_node &tree, size_t index); + isl::schedule_node HoistTensorClusterFootprint(isl::schedule_node tree, size_t index, const isl::union_map &schedule); + std::vector> CollectBufferedFootprints( + const isl::union_set &active_points, const isl::id &tensor_id) const; + std::vector CollectBufferedFootprintsIndexes(const isl::union_set &active_points, + const isl::id &tensor_id) const; + std::shared_ptr GetFootPrintsCluster(const isl::id &tensor_id); + void SetFindBuffer(const isl::id &tensor_id, bool find_buffer); + + void AddStateTensorsDataFlow(); + void AddTensorDataFlow(const std::vector &mem_flow, const std::vector &name_flow); + + // record buffer footprint + void AddOneBufferDefInfo(const isl::id &ancestorId, const std::vector> &data_stream); + void MakeBufferFootprintCluster(BufferDefInfo &tensor_info); + void GatherBufferFootprintDefInfo(const isl::schedule_node &tree, BufferDefInfo &tensor_info); + void GatherFractalDefInfo(const isl::schedule_node &tree, BufferDefInfo &tensor_info, std::vector &sizes); + void HoistIm2colBufferFootprintCluster(const isl::union_map &schedule, const isl::schedule_node &node, int index, + BufferDefInfo &tensor_info); + void MakeMultiBufferFootprint(const isl::union_map &schedule, const isl::schedule_node &node, int &index, + BufferDefInfo &tensor_info); + void ReorderBufferedDefInfos(); + void CollectBufferFootprintDefInfo(BufferDefInfo &tensor_info, const isl::union_map &schedule, + const isl::schedule_node &node); + + void AddGemmTransposeFpCluster(const isl::union_map &schedule); + + private: + // PassInfo &pass_info_; + ScopInfo &scop_info_; + std::queue buffer_footprint_queue_; + + std::shared_ptr gemm_a_transpose_fp_cluster_; + std::shared_ptr gemm_b_transpose_fp_cluster_; + std::shared_ptr im2col_fp_cluster; + + isl::schedule schedule_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif \ No newline at end of file diff --git a/src/poly/schedule_pass/reorder_inner_band.cc b/src/poly/schedule_pass/reorder_inner_band.cc new file mode 100644 index 0000000000000000000000000000000000000000..a83c048f4d8a8ecdffbc33fb69deb9ae4dc78917 --- /dev/null +++ b/src/poly/schedule_pass/reorder_inner_band.cc @@ -0,0 +1,176 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "reorder_inner_band.h" + +namespace akg { +namespace ir { +namespace poly { + +std::vector ExtractDimNames(const isl::aff &aff) { + std::vector dim_names; + int dims = isl_aff_dim(aff.get(), isl_dim_in); + CHECK_GE(dims, 0); + for (int i = 0; i < dims; ++i) { + isl_val *coef_val = isl_aff_get_coefficient_val(aff.get(), isl_dim_in, i); + int coef = isl_val_get_num_si(coef_val); + static_cast(isl_val_free(coef_val)); + if (coef != 0) { + auto dim_name = std::string(isl_aff_get_dim_name(aff.get(), isl_dim_in, i)); + dim_names.push_back(dim_name); + } + } + return dim_names; +} + +isl::multi_union_pw_aff MergeTwoUPALs(const isl::multi_union_pw_aff &partial_schedule, + const std::vector &dims_with_if, + const std::vector &dims_without_if) { + auto num_dims_with_if = dims_with_if.size(); + auto num_dims_without_if = dims_without_if.size(); + CHECK(partial_schedule.size() == num_dims_with_if + num_dims_without_if); + auto new_schedule = partial_schedule; + for (unsigned dim = 0; dim < num_dims_with_if; ++dim) { + new_schedule = new_schedule.set_at(dim, dims_with_if[dim]); + } + for (unsigned dim = 0; dim < num_dims_without_if; ++dim) { + new_schedule = new_schedule.set_at(dim + num_dims_with_if, dims_without_if[dim]); + } + return new_schedule; +} + +void MergeTwoDimMaps(const std::vector &in1, const std::vector &in2, std::vector &out) { + out.resize(in1.size() + in2.size()); + for (unsigned i = 0; i < in1.size(); ++i) { + out[in1[i]] = i; + } + for (unsigned i = 0; i < in2.size(); ++i) { + out[in2[i]] = i + in1.size(); + } +} + +/* + * Reorder the partial schedule such that isl::union_pw_aff with range var in cond_vars are + * ordered before others. + * + * Example: + * [{ S_0[j, k, l] -> [((k) mod 16)] }, + * { S_0[j, k, l] -> [((l) mod 16)] }, + * { S_0[j, k, l] -> [(0)] }, + * { S_0[j, k, l] -> [((j) mod 32)] }] + * + * If "j" appears in the conditions, the partial schedule is transformed to: + * + * [{ S_0[j, k, l] -> [((j) mod 32)] }, + * { S_0[j, k, l] -> [((k) mod 16)] }, + * { S_0[j, k, l] -> [((l) mod 16)] }, + * { S_0[j, k, l] -> [(0)] }] + */ +isl::multi_union_pw_aff ReorderLocalSchedule(const CondVarsMap &cond_vars, + const isl::multi_union_pw_aff &partial_schedule, + std::vector &dim_map, bool &need_update) { + std::vector dims_with_if, dims_without_if; + std::vector with_if_dim_map, without_if_dim_map; + need_update = false; + auto original_upal = partial_schedule.union_pw_aff_list(); + unsigned upal_size = original_upal.size(); + for (unsigned dim = 0; dim < upal_size; ++dim) { + auto dim_schedule = original_upal.get_at(dim); + bool found_dim_in_cond = false; + dim_schedule.get_pw_aff_list().foreach([&cond_vars, &found_dim_in_cond](const isl::pw_aff &stmt_schedule) -> void { + stmt_schedule.foreach_piece([&cond_vars, &found_dim_in_cond](const isl::set &set, const isl::aff &aff) -> void { + isl::id stmt_id = set.get_tuple_id(); + if (cond_vars.count(stmt_id) == 0) return; + const auto &cond_vars_in_stmt = cond_vars.at(stmt_id); + auto dim_names = ExtractDimNames(aff); + for (const auto &dim_name : dim_names) { + if (cond_vars_in_stmt.count(dim_name)) found_dim_in_cond = true; + } + }); + }); + if (found_dim_in_cond) { + with_if_dim_map.push_back(dim); + dims_with_if.push_back(dim_schedule); + need_update = true; + } else { + without_if_dim_map.push_back(dim); + dims_without_if.push_back(dim_schedule); + } + } + + if (need_update) { + MergeTwoDimMaps(with_if_dim_map, without_if_dim_map, dim_map); + return MergeTwoUPALs(partial_schedule, dims_with_if, dims_without_if); 
+ } else { + return partial_schedule; + } +} + +/* + * isl::schedule_node_band does not provide an interface to update the partial schedule, + * so we have to delete the band node, copy other attributes from the original node and + * insert the new partial schedule. + * + * Note that the member coincident values needs to be updated according to the mapping + * between original and new dims. + */ +isl::schedule_node setLocalSchedule(const isl::schedule_node_band &band, + const isl::multi_union_pw_aff &partial_schedule, + const std::vector &dim_map) { + auto removed_band = band.del(); + auto new_band_obj = removed_band.insert_partial_schedule(partial_schedule); + auto new_band = new_band_obj.copy(); + new_band = isl_schedule_node_band_set_permutable(new_band, band.get_permutable()); + auto ast_build_options = band.get_ast_build_options(); + new_band = isl_schedule_node_band_set_ast_build_options(new_band, ast_build_options.copy()); + unsigned n_member = band.n_member(); + CHECK(dim_map.size() == n_member); + for (unsigned i = 0; i < n_member; ++i) { + bool coincident = band.member_get_coincident(i); + unsigned new_member = dim_map[i]; + new_band = isl_schedule_node_band_member_set_coincident(new_band, new_member, coincident); + } + return isl::manage(new_band); +} + +isl::schedule_node RewriteLeafBandNode(const CondVarsMap &cond_vars, const isl::schedule_node_band &band) { + auto partial_schedule = band.get_partial_schedule(); + std::vector dim_map; + bool need_update = false; + auto new_partial_schedule = ReorderLocalSchedule(cond_vars, partial_schedule, dim_map, need_update); + if (!need_update) + return band; + else + return setLocalSchedule(band, new_partial_schedule, dim_map); +} + +isl::schedule ReorderInnerBand::Run(isl::schedule curr_schedule) { + isl::schedule_node root = curr_schedule.get_root(); + auto cond_vars = cond_vars_; + root = root.map_descendant_bottom_up([&cond_vars](const isl::schedule_node &node) -> isl::schedule_node { + bool is_leaf_band = + node.as() && node.n_children() == 1 && node.first_child().as(); + if (!is_leaf_band) return node; + + auto band = node.as(); + if (!band.get_permutable()) return node; + return RewriteLeafBandNode(cond_vars, band); + }); + return root.get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/reorder_inner_band.h b/src/poly/schedule_pass/reorder_inner_band.h new file mode 100644 index 0000000000000000000000000000000000000000..1a483b495bf645cc6707c7aa2605d7f6e2b88c5e --- /dev/null +++ b/src/poly/schedule_pass/reorder_inner_band.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_REORDER_INNER_BAND_H_ +#define POLY_REORDER_INNER_BAND_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Reorder the members of the leaf-band partial schedule (if it is permutable) + * such that loop vars that appear in "if" conditions are the outer loops. 
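+ * For example, a partial schedule over S_0[j, k, l] in which only "j" occurs in an
+ * "if" condition is reordered so that the member depending on "j" comes first
+ * (see the worked example in reorder_inner_band.cc).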
+ * This aims to promote the "if" condition to the outermost loop, and maximize + * the size of unconditional vectorized computation. + */ +class ReorderInnerBand : public SchedulePass { + public: + ReorderInnerBand(const CondVarsMap &cond_vars) : cond_vars_(cond_vars) { pass_name_ = __FUNCTION__; }; + ~ReorderInnerBand(){}; + + virtual isl::schedule Run(isl::schedule sch); + + private: + CondVarsMap cond_vars_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_REORDER_INNER_BAND_H_ diff --git a/src/poly/schedule_pass/reorder_invariant_set_schedule.cc b/src/poly/schedule_pass/reorder_invariant_set_schedule.cc new file mode 100644 index 0000000000000000000000000000000000000000..f37e8a790388f389a829fb6675c5d0b9cd7a0f5a --- /dev/null +++ b/src/poly/schedule_pass/reorder_invariant_set_schedule.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "reorder_invariant_set_schedule.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule ReorderInvariantSetSchedule::Run(isl::schedule sch) { + if (!pass_info_.has_invariant_dependence_) { + return sch; + } + isl::schedule_node root = sch.get_root(); + isl::schedule_node outer_band = GetOuterBand(root); + if (outer_band.isa()) { + std::vector new_filters; + std::vector invariant_filters; + std::vector rest_filters; + for (unsigned int i = 0; i < outer_band.n_children(); ++i) { + isl::schedule_node node = outer_band.get_child(i); + auto filter = node.as(); + isl::union_set sets = filter.get_filter(); + unsigned int invariant_count = 0; + sets.foreach_set([&invariant_count, this](const isl::set &s) -> void { + if (s.n_dim() == 0 && this->pass_info_.invariant_state_.count(s.get_tuple_name()) > 0) { + invariant_count++; + } + }); + + if (invariant_count == sets.n_set()) { + invariant_filters.push_back(i); + } else { + rest_filters.push_back(i); + } + } + + for (unsigned long &invariant_filter : invariant_filters) { + new_filters.push_back(invariant_filter); + } + + for (unsigned long &rest_filter : rest_filters) { + new_filters.push_back(rest_filter); + } + + std::unordered_map old_to_new_map; + for (size_t i = 0; i < new_filters.size(); ++i) { + old_to_new_map.emplace(new_filters[i], i); + } + + outer_band = ReorderFilters(outer_band, old_to_new_map); + } + return outer_band.get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/reorder_invariant_set_schedule.h b/src/poly/schedule_pass/reorder_invariant_set_schedule.h new file mode 100644 index 0000000000000000000000000000000000000000..348bc3ef705c4f636983e0b9c1df3f1768fab2bd --- /dev/null +++ b/src/poly/schedule_pass/reorder_invariant_set_schedule.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_REORDER_INVARIANT_SET_SCHEDULE_H_ +#define POLY_REORDER_INVARIANT_SET_SCHEDULE_H_ + +#include "poly/schedule_pass.h" +#include "poly/pass_info.h" + +namespace akg { +namespace ir { +namespace poly { + +class ReorderInvariantSetSchedule : public SchedulePass { + public: + ReorderInvariantSetSchedule(PassInfo &pass_info) : pass_info_(pass_info) { pass_name_ = __FUNCTION__; } + ~ReorderInvariantSetSchedule() {} + + virtual isl::schedule Run(isl::schedule sch); + + private: + PassInfo &pass_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_REORDER_INVARIANT_SET_SCHEDULE_H_ \ No newline at end of file diff --git a/src/poly/schedule_pass/reorder_mark_nodes.cc b/src/poly/schedule_pass/reorder_mark_nodes.cc new file mode 100644 index 0000000000000000000000000000000000000000..dee20fb5276ba472ce693d2bca67b2474a8addf5 --- /dev/null +++ b/src/poly/schedule_pass/reorder_mark_nodes.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "reorder_mark_nodes.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule ReorderMarkNodes::Run(isl::schedule schedule_mark) { + auto fn = [](isl::schedule_node node) -> isl::schedule_node { + if (node.isa()) { + // mark node cannot be inserted between sequence node and its filter children, so skip reorder + if (node.get_child(0).as()) return node; + + std::string mark_id = node.as().get_id().get_name(); + size_t pos = mark_id.find(REALIZE_); + if (pos != std::string::npos) { + node = node.del(); + node = node.get_child(0); + node = node.insert_mark(isl::id(node.ctx(), mark_id)); + node = node.parent(); + } + } + return node; + }; + return schedule_mark.get_root().map_descendant_bottom_up(fn).get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/reorder_mark_nodes.h b/src/poly/schedule_pass/reorder_mark_nodes.h new file mode 100644 index 0000000000000000000000000000000000000000..325812e84dffa933254a561feab0de0864a1f215 --- /dev/null +++ b/src/poly/schedule_pass/reorder_mark_nodes.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_REORDER_MARK_NODES_H_ +#define POLY_REORDER_MARK_NODES_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +class ReorderMarkNodes : public SchedulePass { + public: + ReorderMarkNodes() { pass_name_ = __FUNCTION__; }; + ~ReorderMarkNodes(){}; + + virtual isl::schedule Run(isl::schedule schedule_mark); +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_REORDER_MARK_NODES_H_ diff --git a/src/poly/reschedule.cc b/src/poly/schedule_pass/reschedule.cc similarity index 88% rename from src/poly/reschedule.cc rename to src/poly/schedule_pass/reschedule.cc index 969160905fce6a9ff31f0cb5184ba3fea072e283..3b2ed718b602355970412bc85a880d9d7c845fa2 100644 --- a/src/poly/reschedule.cc +++ b/src/poly/schedule_pass/reschedule.cc @@ -14,24 +14,16 @@ * limitations under the License. */ -#include "poly/reschedule.h" +#include "reschedule.h" -#include -#include -#include - -#include #include -#include -#include - #include "poly/dump_log.h" namespace akg { namespace ir { namespace poly { -bool Transform::IsL1OrUbMark(const isl::schedule_node &node) { +bool Reschedule::IsL1OrUbMark(const isl::schedule_node &node) { if (node.isa()) { auto tag = node.as().get_id().get_name(); if (tag == REALIZE_L1 || tag == REALIZE_UB) return true; @@ -39,7 +31,7 @@ bool Transform::IsL1OrUbMark(const isl::schedule_node &node) { return false; } -bool Transform::IsL0OrUbL0Mark(const isl::schedule_node &node) { +bool Reschedule::IsL0OrUbL0Mark(const isl::schedule_node &node) { if (node.isa()) { auto tag = node.as().get_id().get_name(); if (tag == REALIZE_L0 || tag == REALIZE_UBL0) return true; @@ -56,7 +48,7 @@ bool Transform::IsL0OrUbL0Mark(const isl::schedule_node &node) { * options to "l0_build_options_" since we need to retrieve it after building * the whole schedule. */ -void Transform::CollectTileBandData(const isl::schedule_node &node, struct TileBandData *tile_band_data) { +void Reschedule::CollectTileBandData(const isl::schedule_node &node, struct TileBandData *tile_band_data) { CHECK(node.isa()) << "has to be a band node" << std::endl; tile_band_data->l0_tiled = false; @@ -85,7 +77,7 @@ void Transform::CollectTileBandData(const isl::schedule_node &node, struct TileB * options could be retrieved directly when "node" is an L1/UB tile * band, since the schedule tree is not anchored. 
*/ -isl::schedule_node Transform::RetrieveTileBandData(isl::schedule_node node, struct TileBandData *tile_band_data) { +isl::schedule_node Reschedule::RetrieveTileBandData(isl::schedule_node node, struct TileBandData *tile_band_data) { node = node.insert_partial_schedule(tile_band_data->mupa); CHECK(node.isa()) << "node has to be a band node" << std::endl; node = node.as().set_permutable(static_cast(tile_band_data->permutable)); @@ -108,8 +100,8 @@ isl::schedule_node Transform::RetrieveTileBandData(isl::schedule_node node, stru return node; } -isl::schedule_node Transform::RetrieveNodeList(isl::schedule_node node, - const std::vector &node_list) { +isl::schedule_node Reschedule::RetrieveNodeList(isl::schedule_node node, + const std::vector &node_list) { auto n = static_cast(node_list.size()); if (!n) return node; @@ -151,8 +143,8 @@ isl::schedule_node Transform::RetrieveNodeList(isl::schedule_node node, return node; } -isl::schedule_node Transform::RetrieveAstBuildOptions(isl::schedule_node node, const isl::union_set &options) { - node = Transform::GetOuterBand(node); +isl::schedule_node Reschedule::RetrieveAstBuildOptions(isl::schedule_node node, const isl::union_set &options) { + node = GetOuterBand(node); if (node.isa()) { node = node.as().set_ast_build_options(options); return node; @@ -230,34 +222,6 @@ void ConstructNewOrder(std::unordered_map &map) { map = new_order; } -/* Reorder filters of a sequence/set node. - * node: must be a sequence or set node. - * old_to_new_map: map from original child position to new child position. - * The caller should make sure that there are no duplicate values. - */ -isl::schedule_node ReorderFilters(const isl::schedule_node &node, - const std::unordered_map &old_to_new_map) { - auto n_children = node.n_children(); - isl_schedule_tree *old_tree = isl_schedule_node_get_tree(node.get()); - CHECK(old_tree != nullptr); - isl_schedule_tree *new_tree = isl_schedule_node_get_tree(node.get()); - CHECK(new_tree != nullptr); - for (auto &it : old_to_new_map) { - auto old_pos = it.first; - auto new_pos = it.second; - CHECK(old_pos < n_children); - CHECK(new_pos < n_children); - isl_schedule_tree *old_child = isl_schedule_tree_get_child(old_tree, old_pos); - CHECK(old_child != nullptr); - new_tree = isl_schedule_tree_replace_child(new_tree, new_pos, old_child); - CHECK(new_tree != nullptr); - } - static_cast(isl_schedule_tree_free(old_tree)); - isl_schedule_node *new_node = isl_schedule_node_graft_tree(node.copy(), new_tree); - CHECK(new_node != nullptr); - return isl::manage(new_node); -} - // Restore the order of filter nodes. 
isl::schedule_node RestoreOrderOfFilters(const isl::schedule_node &node, const std::vector &order) { std::unordered_map id_to_order_map; @@ -297,18 +261,16 @@ isl::schedule_node RestoreOrderOfSequenceNodes(isl::schedule_node node, return node; } -bool Transform::ValidateReorderedSchedule(const isl::schedule &new_schedule) { - auto backup_schedule = schedule_; - schedule_ = new_schedule; - isl::union_map new_dependence = ComputeAllDependences(); - bool is_valid = new_dependence.is_subset(dependences_); - schedule_ = backup_schedule; +bool Reschedule::ValidateReorderedSchedule(const isl::schedule &new_schedule) { + isl::union_map new_dependence = ComputeAllDependences(new_schedule, scop_info_.analysis_result_.GetReads(), + scop_info_.analysis_result_.GetWrites()); + bool is_valid = new_dependence.is_subset(pass_info_.dependences_); return is_valid; } -isl::schedule_node Transform::TryRestoreStmtOrder(const isl::schedule_node &node, - const std::vector &filter_total_order, - const std::vector> &filter_partial_order) { +isl::schedule_node Reschedule::TryRestoreStmtOrder(const isl::schedule_node &node, + const std::vector &filter_total_order, + const std::vector> &filter_partial_order) { if (filter_total_order.empty()) return node; if (filter_partial_order.empty()) return node; @@ -328,12 +290,12 @@ isl::schedule_node Transform::TryRestoreStmtOrder(const isl::schedule_node &node } // Loop distribution by serializing sccs -isl::schedule Transform::RescheduleSerializeSccs(const isl::union_set &active_domain, const bool need_dist) { - auto ctx = constraints_.ctx(); +isl::schedule Reschedule::RescheduleSerializeSccs(const isl::union_set &active_domain, const bool need_dist) const { + auto ctx = pass_info_.constraints_.ctx(); auto wasSerializingSccs = isl_options_get_schedule_serialize_sccs(ctx.get()); isl_stat status = isl_options_set_schedule_serialize_sccs(ctx.get(), static_cast(need_dist)); CHECK(status == isl_stat_ok); - auto constraints = constraints_.intersect_domain(active_domain); + auto constraints = pass_info_.constraints_.intersect_domain(active_domain); auto new_schedule = constraints.compute_schedule(); status = isl_options_set_schedule_serialize_sccs(ctx.get(), wasSerializingSccs); CHECK(status == isl_stat_ok); @@ -341,8 +303,9 @@ isl::schedule Transform::RescheduleSerializeSccs(const isl::union_set &active_do } // Save ordering of filter children, and restore the ordering after reschedule -isl::schedule_node Transform::ReschedulePreserveFilterOrder(const isl::schedule_node &node, - const isl::union_set &active_domain, const bool need_dist) { +isl::schedule_node Reschedule::ReschedulePreserveFilterOrder(const isl::schedule_node &node, + const isl::union_set &active_domain, + const bool need_dist) { auto filter_total_order = GetStmtTotalOrdering(node); auto filter_partial_order = GetStmtPartialOrdering(node); @@ -357,7 +320,7 @@ isl::schedule_node Transform::ReschedulePreserveFilterOrder(const isl::schedule_ } // Save partial schedule, permutable and coincident attrs of a band. -PointBandInfo Transform::SavePointBand(const isl::schedule_node &node) { +PointBandInfo Reschedule::SavePointBand(const isl::schedule_node &node) { PointBandInfo point_band_info; CHECK(node.isa()); auto band = node.as(); @@ -373,7 +336,7 @@ PointBandInfo Transform::SavePointBand(const isl::schedule_node &node) { /* Restore saved partial schedule, permutable and coincident attrs of a band. * Input must be a band node. 
*/ -isl::schedule_node Transform::SetPointBandInfo(isl::schedule_node node, const PointBandInfo &point_band_info) { +isl::schedule_node Reschedule::SetPointBandInfo(isl::schedule_node node, const PointBandInfo &point_band_info) { node = node.del(); node = node.insert_partial_schedule(point_band_info.mupa); auto n = node.as().n_member(); @@ -388,7 +351,7 @@ isl::schedule_node Transform::SetPointBandInfo(isl::schedule_node node, const Po /* Restore saved partial schedule, permutable and coincident attrs of each band in the node. * Input may be a sequence, set or band node. */ -isl::schedule_node Transform::RestorePointBandInfo(isl::schedule_node node, const PointBandInfo &point_band_info) { +isl::schedule_node Reschedule::RestorePointBandInfo(isl::schedule_node node, const PointBandInfo &point_band_info) { // Retrieve point band if a sequence/set node is introduced if (IsSequenceOrSet(node)) { // Update point band for each scc filter @@ -490,7 +453,7 @@ isl::schedule_node Transform::RestorePointBandInfo(isl::schedule_node node, cons * * Return the root of the schedule after rescheduling. */ -isl::schedule_node Transform::RescheduleSchTree(const isl::schedule_node &root) { +isl::schedule_node Reschedule::RescheduleSchTree(const isl::schedule_node &root) { bool need_dist = true; // Return "root" if given an inappropriate node if (!root.isa() && !root.isa()) return root; @@ -556,7 +519,7 @@ isl::schedule_node Transform::RescheduleSchTree(const isl::schedule_node &root) return node.get_schedule().get_root(); } - auto scalar_filter = [](isl::schedule_node node) { + auto scalar_filter = [](const isl::schedule_node &node) { if (!node.isa()) { return false; } @@ -564,7 +527,7 @@ isl::schedule_node Transform::RescheduleSchTree(const isl::schedule_node &root) auto filter = node.as(); isl::union_set sets = filter.get_filter(); bool scalar = true; - sets.foreach_set([&scalar](const isl::set s) -> void { + sets.foreach_set([&scalar](const isl::set &s) -> void { if (s.n_dim() > 0) { scalar = false; } @@ -753,12 +716,12 @@ static isl::schedule_node IslScheduleNodeReplaceChild(const isl::schedule_node & /* Reschedule the subtree of each mark node for loop distribution. * - * Transform::Reschedule assumes the mark nodes are the outer bands. + * Reschedule::Reschedule assumes the mark nodes are the outer bands. * This function do not have the assumption, so it supports tiled inner bands. * * Assume mark nodes are not nested, so this is only suitable for vector ops. 
*/ -isl::schedule_node Transform::RescheduleInnerBand(const isl::schedule_node &root) { +isl::schedule_node Reschedule::RescheduleInnerBand(const isl::schedule_node &root) { return root.map_descendant_bottom_up([this](const isl::schedule_node &node) -> isl::schedule_node { if (!IsL1OrUbMark(node) && !IsL0OrUbL0Mark(node)) return node; @@ -780,6 +743,46 @@ isl::schedule_node Transform::RescheduleInnerBand(const isl::schedule_node &root }); } +void Reschedule::Dump() { + std::ofstream of; + of.open("transform.log", std::ios::out); + if (!of.is_open()) { + return; + } + PrintHeader(of, "L1/UB tile band build options"); + for (const auto &option : l1_build_options_) { + of << option << std::endl; + } + + PrintHeader(of, "L0 tile band build options"); + for (const auto &option : l0_build_options_) { + of << option << std::endl; + } + + PrintHeader(of, "nodes from root to L1/UB band"); + for (const auto &node : node_list_0_) { + of << node << std::endl; + } + + PrintHeader(of, "nodes from L1/UB band to L0/UBL0 band"); + for (const auto &node : node_list_1_) { + of << node << std::endl; + } + + PrintHeader(of, "nodes from L0/UBL0 band to point band"); + for (const auto &node : node_list_2_) { + of << node << std::endl; + } +} +isl::schedule Reschedule::Run(isl::schedule curr_schedule) { + isl::schedule sched = curr_schedule; + isl::schedule_node root = sched.get_root(); + if (scop_info_.user_config_.GetTileInnerBand()) + sched = RescheduleInnerBand(root).get_schedule(); + else + sched = RescheduleSchTree(root).get_schedule(); + return sched; +} } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/schedule_pass/reschedule.h b/src/poly/schedule_pass/reschedule.h new file mode 100644 index 0000000000000000000000000000000000000000..c74c2916e2cb6c90e549008469c0d031343ee2a1 --- /dev/null +++ b/src/poly/schedule_pass/reschedule.h @@ -0,0 +1,104 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef POLY_RESCHEDULE_H_ +#define POLY_RESCHEDULE_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +struct PointBandInfo { + isl::multi_union_pw_aff mupa; + size_t n_member{0}; + bool permutable{false}; + std::vector coincident; +}; + +// data structure for recording tile band data +struct TileBandData { + // flag indicating whether L0 tiled + bool l0_tiled; + // mark node of the tile band, if any + isl::schedule_node mark; + // mark node of conv_gemm, if any + isl::schedule_node gemm_mark; + // members of tile band + unsigned int n_member; + // schedule mupa + isl::multi_union_pw_aff mupa; + // permutable + bool permutable; + // coincident + std::vector coincident; + // ast build options + isl::union_set ast_build_options; +}; + +class Reschedule : public SchedulePass { + public: + Reschedule(ScopInfo &scop_info, PassInfo &pass_info) : scop_info_(scop_info), pass_info_(pass_info) { + pass_name_ = __FUNCTION__; + }; + ~Reschedule() {} + + virtual isl::schedule Run(isl::schedule sch); + + private: + static bool IsL1OrUbMark(const isl::schedule_node &node); + static bool IsL0OrUbL0Mark(const isl::schedule_node &node); + void CollectTileBandData(const isl::schedule_node &node, TileBandData *tile_band_data); + static isl::schedule_node RetrieveTileBandData(isl::schedule_node node, TileBandData *tile_band_data); + static isl::schedule_node RetrieveNodeList(isl::schedule_node node, const std::vector &node_list); + static isl::schedule_node RetrieveAstBuildOptions(isl::schedule_node node, const isl::union_set &options); + bool ValidateReorderedSchedule(const isl::schedule &new_schedule); + isl::schedule_node TryRestoreStmtOrder(const isl::schedule_node &node, const std::vector &filter_total_order, + const std::vector> &filter_partial_order); + isl::schedule RescheduleSerializeSccs(const isl::union_set &active_domain, const bool need_dist) const; + isl::schedule_node ReschedulePreserveFilterOrder(const isl::schedule_node &node, const isl::union_set &active_domain, + const bool need_dist); + static PointBandInfo SavePointBand(const isl::schedule_node &node); + static isl::schedule_node SetPointBandInfo(isl::schedule_node node, const PointBandInfo &point_band_info); + static isl::schedule_node RestorePointBandInfo(isl::schedule_node node, const PointBandInfo &point_band_info); + isl::schedule_node RescheduleSchTree(const isl::schedule_node &root); + isl::schedule_node RescheduleInnerBand(const isl::schedule_node &root); + void Dump(); + + private: + ScopInfo &scop_info_; + PassInfo &pass_info_; + // for recording L1/UB tile band build options + std::vector l1_build_options_; + + // for recording L0 tile band build options + std::vector l0_build_options_; + + // for recording nodes along the path from root to L1/UB band + std::vector node_list_0_; + + // for recording nodes along the path from L1/UB band to L0/UBL0 band + std::vector node_list_1_; + + // for recording nodes along the path from L0/UBL0 band to point band + std::vector node_list_2_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_RESCHEDULE_H_ diff --git a/src/poly/schedule_pass/reset_coincidence_of_reduce.cc b/src/poly/schedule_pass/reset_coincidence_of_reduce.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f76314ec054f1b5cd28bde94adde4c4e15906aa --- /dev/null +++ b/src/poly/schedule_pass/reset_coincidence_of_reduce.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "reset_coincidence_of_reduce.h" + +namespace akg { +namespace ir { +namespace poly { +bool ResetCoincidenceOfReduce::IsStmtScheduleContainsReduceAxis( + const isl::pw_aff &stmt, const std::unordered_set &reduce_axis_list) { + int num_dims = stmt.domain().n_dim(); + isl_space *domain_space = stmt.domain().get_space().get(); + for (int dim = 0; dim < num_dims; ++dim) { + const char *axis_name = isl_space_get_dim_name(domain_space, isl_dim_out, dim); + if (axis_name == nullptr) continue; + if (reduce_axis_list.count(axis_name) == 0) continue; + if (isl_pw_aff_involves_dims(stmt.get(), isl_dim_in, dim, 1)) return true; + } + return false; +} + +bool ResetCoincidenceOfReduce::IsDimScheduleContainsReduceAxis(const isl::union_pw_aff &schedule) { + auto reduce_stmts = pass_info_.reduce_stmts_; + bool found_reduce_axis = false; + auto stmt_list = schedule.get_pw_aff_list(); + stmt_list.foreach([&found_reduce_axis, &reduce_stmts, this](const isl::pw_aff &stmt) -> void { + isl::id stmt_id = stmt.domain().get_tuple_id(); + if (reduce_stmts.count(stmt_id)) { + std::unordered_set reduce_axis_list; + for (const auto &axis : reduce_stmts.at(stmt_id)) { + reduce_axis_list.insert(axis); + } + if (IsStmtScheduleContainsReduceAxis(stmt, reduce_axis_list)) { + found_reduce_axis = true; + } + } + }); + return found_reduce_axis; +} + +isl::schedule ResetCoincidenceOfReduce::Run(isl::schedule curr_schedule) { + pass_info_.reduce_stmts_ = scop_info_.analysis_result_.GetReduceStmtMap(); + + const auto &new_schedule = curr_schedule; + const auto &reduce_stmts = pass_info_.reduce_stmts_; + auto fn = [&reduce_stmts, this](isl::schedule_node node) -> isl::schedule_node { + if (auto band = node.as()) { + int num_dims = static_cast(band.n_member()); + for (int dim = 0; dim < num_dims; ++dim) { + bool is_coincident = band.member_get_coincident(dim); + if (!is_coincident) continue; + auto dim_schedule = band.get_partial_schedule().get_union_pw_aff(dim); + if (IsDimScheduleContainsReduceAxis(dim_schedule)) { + LOG(INFO) << "reset coincidence of reduce axis on dim " << dim << " in partial schedule: " << dim_schedule; + node = band.member_set_coincident(dim, false); + band = node.as(); + } + } + } + return node; + }; + return new_schedule.get_root().map_descendant_bottom_up(fn).get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/reset_coincidence_of_reduce.h b/src/poly/schedule_pass/reset_coincidence_of_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..23989b8e47e06f874ef67783c346727416c00f68 --- /dev/null +++ b/src/poly/schedule_pass/reset_coincidence_of_reduce.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_RESET_COINCIDENCE_OF_REDUCE_H_ +#define POLY_RESET_COINCIDENCE_OF_REDUCE_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Reset ths coincidence of reduce axis in partial schedule to 0 if the original coincidence is 1. + * This transform can prevent reduce axes from being parallelled. + */ +class ResetCoincidenceOfReduce : public SchedulePass { + public: + ResetCoincidenceOfReduce(ScopInfo &scop_info, PassInfo &pass_info) : scop_info_(scop_info), pass_info_(pass_info) { + pass_name_ = __FUNCTION__; + }; + ~ResetCoincidenceOfReduce(){}; + + virtual isl::schedule Run(isl::schedule sch); + + private: + ScopInfo &scop_info_; + PassInfo &pass_info_; + + bool IsStmtScheduleContainsReduceAxis(const isl::pw_aff &stmt, + const std::unordered_set &reduce_axis_list); + + bool IsDimScheduleContainsReduceAxis(const isl::union_pw_aff &schedule); +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_RESET_COINCIDENCE_OF_REDUCE_H_ diff --git a/src/poly/rm_self_dep.cc b/src/poly/schedule_pass/rm_self_dep.cc similarity index 88% rename from src/poly/rm_self_dep.cc rename to src/poly/schedule_pass/rm_self_dep.cc index 61adf39f9cefa634692183ae936a5e297abd0a46..bf79d5e1ef97723e24d1019a8c11df86eb998d06 100644 --- a/src/poly/rm_self_dep.cc +++ b/src/poly/schedule_pass/rm_self_dep.cc @@ -14,17 +14,16 @@ * limitations under the License. */ -#include "poly/rm_self_dep.h" +#include "rm_self_dep.h" #include -#include #include -#include #include #include #include +#include "poly/schedule_pass.h" #include "poly/dump_log.h" namespace akg { @@ -646,44 +645,46 @@ static bool IsMultiAxisSelfDependence(const isl::union_map &dependences, const i /* * Removes self dependence of (multi-axis) reduce operations. 
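 *
 * A self dependence is a map whose domain and range refer to the same
 * statement tuple, e.g. S_0[i, k] -> S_0[i, k']. If the statement is
 * recognized as a reduction (via CheckIsStmtReduceOp on the reads/writes or
 * on the dependences), its self dependences are dropped so that the serial
 * reduce axis does not over-constrain the scheduler, and the statement with
 * its reduce axes is recorded through RecordReduceStmt for later passes such
 * as ResetCoincidenceOfReduce. Self dependences that span only a single axis
 * are preserved.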
*/ -isl::union_map Transform::RemoveReduceOpSelfDependence(bool multiAxisOnly) { - isl::union_map preserved_dependences = isl::union_map::empty(dependences_.get_space()); +isl::union_map RemoveReduceOpSelfDependence(ScopInfo &scop_info, PassInfo &pass_info) { + isl::union_map preserved_dependences = isl::union_map::empty(pass_info.dependences_.get_space()); std::unordered_map is_tuple_reduce_op; - dependences_.foreach_map([&, this](const isl::map m) -> void { - if (m.domain().get_tuple_id() != m.range().get_tuple_id()) { - preserved_dependences = preserved_dependences.add_map(m); - } else { // self dependence - isl::id tuple_id = m.domain().get_tuple_id(); - std::string tuple_id_key = tuple_id.get_name(); - if (is_tuple_reduce_op.count(tuple_id_key) == 0) { - std::vector reduce_axis_list; - if (multiAxisOnly && !IsMultiAxisSelfDependence(dependences_, tuple_id)) { - is_tuple_reduce_op[tuple_id_key] = false; - } else { - is_tuple_reduce_op[tuple_id_key] = - CheckIsStmtReduceOp(scop_.data_.reads, scop_.data_.writes, tuple_id, reduce_axis_list) || - CheckIsStmtReduceOp(dependences_, tuple_id, reduce_axis_list); + pass_info.dependences_.foreach_map( + [&scop_info, &pass_info, &preserved_dependences, &is_tuple_reduce_op](const isl::map &m) -> void { + if (m.domain().get_tuple_id() != m.range().get_tuple_id()) { + preserved_dependences = preserved_dependences.add_map(m); + } else { // self dependence + isl::id tuple_id = m.domain().get_tuple_id(); + std::string tuple_id_key = tuple_id.get_name(); + if (is_tuple_reduce_op.count(tuple_id_key) == 0) { + std::vector reduce_axis_list; + if (!IsMultiAxisSelfDependence(pass_info.dependences_, tuple_id)) { + is_tuple_reduce_op[tuple_id_key] = false; + } else { + is_tuple_reduce_op[tuple_id_key] = + CheckIsStmtReduceOp(scop_info.analysis_result_.GetReads(), scop_info.analysis_result_.GetWrites(), + tuple_id, reduce_axis_list) || + CheckIsStmtReduceOp(pass_info.dependences_, tuple_id, reduce_axis_list); + } + + if (is_tuple_reduce_op[tuple_id_key]) { + scop_info.analysis_result_.RecordReduceStmt(tuple_id, reduce_axis_list); + } } - - if (is_tuple_reduce_op[tuple_id_key]) { - scop_.RecordReduceStmt(tuple_id, reduce_axis_list); + if (!is_tuple_reduce_op[tuple_id_key]) { + preserved_dependences = preserved_dependences.add_map(m); } } - if (!is_tuple_reduce_op[tuple_id_key]) { - preserved_dependences = preserved_dependences.add_map(m); - } - } - }); + }); return preserved_dependences; } /* * Removes all self dependences in the program. Use with special care. 
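 *
 * Unlike RemoveReduceOpSelfDependence above, every map whose domain and range
 * belong to the same statement tuple is dropped, whether or not the statement
 * is a reduction, so dependences that are required for correctness may be
 * discarded as well; hence the warning above.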
*/ -isl::union_map Transform::RemoveSelfDependence() { - isl::union_map preserved = isl::union_map::empty(dependences_.get_space()); - isl::union_map removed = isl::union_map::empty(dependences_.get_space()); - dependences_.foreach_map([&](const isl::map m) -> void { +isl::union_map RemoveSelfDependence(PassInfo &pass_info) { + isl::union_map preserved = isl::union_map::empty(pass_info.dependences_.get_space()); + isl::union_map removed = isl::union_map::empty(pass_info.dependences_.get_space()); + pass_info.dependences_.foreach_map([&](const isl::map &m) -> void { if (m.domain().get_tuple_id() != m.range().get_tuple_id()) { preserved = preserved.add_map(m); } else { @@ -694,6 +695,50 @@ isl::union_map Transform::RemoveSelfDependence() { return preserved; } +isl::union_map RemoveInvariantDependence(const isl::schedule &schedule, PassInfo &pass_info) { + isl::schedule_node root = schedule.get_root(); + isl::schedule_node outer_band = GetOuterBand(root); + if (outer_band.as() || outer_band.as()) { + for (unsigned int i = 0; i < outer_band.n_children(); ++i) { + isl::schedule_node node = outer_band.get_child(i); + auto filter = node.as(); + isl::union_set sets = filter.filter(); + if (sets.n_set() == 1) { + sets.foreach_set([&pass_info](const isl::set &s) -> void { + if (s.n_dim() == 0) { + // scalar single filter + if (pass_info.invariant_state_.count(s.get_tuple_name()) == 0) { + pass_info.invariant_state_.emplace(s.get_tuple_name(), 1); + } + } + }); + } + } + } + + if (pass_info.invariant_state_.empty()) { + return pass_info.dependences_; + } + + isl::union_map preserved = isl::union_map::empty(pass_info.dependences_.get_space()); + + pass_info.dependences_.foreach_map([&preserved, &pass_info](const isl::map &m) -> void { + auto map_domain = m.domain(); + auto map_range = m.range(); + bool invariant_dependence = (pass_info.invariant_state_.count(map_domain.get_tuple_name()) > 0) && + (map_domain.n_dim() == 0) && (map_range.n_dim() > 0); + + if (invariant_dependence) { + pass_info.has_invariant_dependence_ = true; + } + + if (!invariant_dependence) { + preserved = preserved.add_map(m); + } + }); + return preserved; +} + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/rm_self_dep.h b/src/poly/schedule_pass/rm_self_dep.h similarity index 97% rename from src/poly/rm_self_dep.h rename to src/poly/schedule_pass/rm_self_dep.h index f9f8c33182163a31f1c31443f772a76d63b210a9..93de3cbdeef133d9cf0c4ec3237115081f15edde 100644 --- a/src/poly/rm_self_dep.h +++ b/src/poly/schedule_pass/rm_self_dep.h @@ -16,8 +16,8 @@ #ifndef POLY_RM_SELF_DEP_H_ #define POLY_RM_SELF_DEP_H_ -#pragma once -#include "poly/transform.h" +#include +#include namespace akg { namespace ir { diff --git a/src/poly/schedule_pass/set_all_coincidence.cc b/src/poly/schedule_pass/set_all_coincidence.cc new file mode 100644 index 0000000000000000000000000000000000000000..39f1da1fc5d1d8ac5393599ea28a5257d8f8bba6 --- /dev/null +++ b/src/poly/schedule_pass/set_all_coincidence.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "set_all_coincidence.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule SetAllCoincidence::Run(isl::schedule curr_schedule) { + const auto &new_schedule = curr_schedule; + auto fn = [](isl::schedule_node node) -> isl::schedule_node { + if (auto band = node.as()) { + int num_dims = static_cast(band.n_member()); + for (int dim = 0; dim < num_dims; ++dim) { + bool is_coincident = band.member_get_coincident(dim); + if (is_coincident) continue; + node = band.member_set_coincident(dim, true); + band = node.as(); + } + } + return node; + }; + return new_schedule.get_root().map_descendant_bottom_up(fn).get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/set_all_coincidence.h b/src/poly/schedule_pass/set_all_coincidence.h new file mode 100644 index 0000000000000000000000000000000000000000..7f924ffda8038426ec199dbe91b0d7e56dadab95 --- /dev/null +++ b/src/poly/schedule_pass/set_all_coincidence.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_SET_ALL_COINCIDENCE_H_ +#define POLY_SET_ALL_COINCIDENCE_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Sometimes, coincident is set to `0` for some axes that can actually be parallelised in computed schedule tree. + * Since we have no idea why these cases happen, we offer such transfrom to set all coincident to `1`. + * Please be careful to do such transfrom since it may cause some incorrect result. + */ +class SetAllCoincidence : public SchedulePass { + public: + SetAllCoincidence() { pass_name_ = __FUNCTION__; }; + ~SetAllCoincidence(){}; + + virtual isl::schedule Run(isl::schedule sch); +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_SET_ALL_COINCIDENCE_H_ diff --git a/src/poly/schedule_pass/sink_c0.cc b/src/poly/schedule_pass/sink_c0.cc new file mode 100644 index 0000000000000000000000000000000000000000..dcf2e53733371027163bf3ed1001f66a60f5d7fb --- /dev/null +++ b/src/poly/schedule_pass/sink_c0.cc @@ -0,0 +1,199 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "sink_c0.h" + +namespace akg { +namespace ir { +namespace poly { + +bool SinkC0::FindC0Schedule(const isl::pw_aff_list &paList) { + for (unsigned int upaIdx = 0; upaIdx < paList.size(); ++upaIdx) { + isl::pw_aff pa = paList.get_at(upaIdx); + int64_t inDimSize = isl_pw_aff_dim(pa.get(), isl_dim_in); + CHECK_NE(inDimSize, -1); + const char *lastInDim = isl_pw_aff_get_dim_name(pa.get(), isl_dim_in, inDimSize - 1); + if (lastInDim == nullptr) { + continue; + } + std::string lastAxis = lastInDim; + // pw_aff { S_4[n, c1, kh, oh, c0] -> [(n)] } + // to do use isl api to mark schedule axis + std::string pwAffStr = pa.to_str(); + std::size_t arrowPos = pwAffStr.find("->"); + if (arrowPos == std::string::npos) { + continue; + } + std::string rang = pwAffStr.substr(arrowPos + 2, pwAffStr.size() - (arrowPos + 2)); + std::size_t leftBracket = rang.find("("); + std::size_t rightBracket = rang.find(")"); + if ((leftBracket == std::string::npos) || (rightBracket == std::string::npos) || + (rightBracket <= leftBracket + 1)) { + continue; + } + std::string scheduleAxis = rang.substr(leftBracket + 1, rightBracket - leftBracket - 1); + if (lastAxis == scheduleAxis) { + // lastIdxSchedule[i] = true; + // findC0Schedule = true; + // break; + return true; + } + } + return false; +} + +void SinkC0::ExchangeCoincident(std::vector &coincident, const isl::schedule_node &node, + const std::unordered_map lastIdxSchedule, const int &n) { + // save coincident value for this band + std::vector coincidentOld; + for (int i = 0; i < n; ++i) { + coincidentOld.push_back(node.as().member_get_coincident(i)); + } + + // exchange last axis coincident to last position + for (int i = 0; i < n; ++i) { + if (lastIdxSchedule.count(i) > 0) { + continue; + } + coincident.push_back(coincidentOld[i]); + } + + for (auto item : lastIdxSchedule) { + CHECK_GE(item.first, 0) << "index of coincident can not be negative: " << item.first; + coincident.push_back(coincidentOld[item.first]); + } +} + +/* ***************************************************** + * Initialization part: + * get partial_schedule info and union_pw_aff_list from band node + * partial_schedule is a multi_union_pw_aff as follows: + * [ + { S_4[n, c1, kh, oh, c0] -> [(n)]; S_3[n, c1, oh, ow, c0] -> [(n)]; S_5[n, c1, kh, oh, ow, c0] -> [(n)]; S_6[n, +c1, kh, kw, oh, ow, c0] -> [(n)] }, { S_4[n, c1, kh, oh, c0] -> [(c1)]; S_3[n, c1, oh, ow, c0] -> [(c1)]; S_5[n, c1, kh, +oh, ow, c0] -> [(c1)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(c1)] }, { S_4[n, c1, kh, oh, c0] -> [(oh)]; S_3[n, c1, oh, +ow, c0] -> [(oh)]; S_5[n, c1, kh, oh, ow, c0] -> [(oh)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(oh)] }, { S_4[n, c1, kh, +oh, c0] -> [(0)]; S_3[n, c1, oh, ow, c0] -> [(ow)]; S_5[n, c1, kh, oh, ow, c0] -> [(1 + ow)]; S_6[n, c1, kh, kw, oh, ow, +c0] -> [(ow)] }, { S_4[n, c1, kh, oh, c0] -> [(c0)]; S_3[n, c1, oh, ow, c0] -> [(c0)]; S_5[n, c1, kh, oh, ow, c0] -> +[(c0)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(c0)] }, { S_4[n, c1, kh, oh, c0] -> [(kh)]; S_3[n, c1, oh, ow, c0] -> [(0)]; +S_5[n, c1, kh, oh, ow, c0] -> [(kh)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(kh)] }, { S_4[n, c1, kh, oh, c0] -> [(0)]; +S_3[n, c1, oh, ow, c0] -> [(0)]; S_5[n, c1, kh, oh, ow, c0] -> [(0)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(-kw)] } + ] + * Is union_pw_aff_list(upal) the other form of multi_union_pw_aff ? 
and it can not print in LOG(INFO) + * but we need it during update, at least we make a new multi_union_pw_aff from union_pw_aff_list + * and add it to the band node, shown in the following pseudo-code + * isl::union_pw_aff_list upal = isl::union_pw_aff_list(); + * ... ... + * update strategy of upal ... + * ... ... + * isl::multi_union_pw_aff mupa = isl::multi_union_pw_aff(partial_schedule.get_space(), upal); + * node = node.del(); + * node = node.insert_partial_schedule(mupa); + * + * The update strategy of SinkC0 is moving the schedule of axis of C0 with every statement + * to the end of the multi_union_pw_aff, the purpose result is shown in the following: + * +[ +{ S_4[n, c1, kh, oh, c0] -> [(n)]; S_3[n, c1, oh, ow, c0] -> [(n)]; S_5[n, c1, kh, oh, ow, c0] -> [(n)]; S_6[n, c1, kh, +kw, oh, ow, c0] -> [(n)] }, { S_4[n, c1, kh, oh, c0] -> [(c1)]; S_3[n, c1, oh, ow, c0] -> [(c1)]; S_5[n, c1, kh, oh, ow, +c0] -> [(c1)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(c1)] }, { S_4[n, c1, kh, oh, c0] -> [(oh)]; S_3[n, c1, oh, ow, c0] -> +[(oh)]; S_5[n, c1, kh, oh, ow, c0] -> [(oh)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(oh)] }, { S_4[n, c1, kh, oh, c0] -> +[(0)]; S_3[n, c1, oh, ow, c0] -> [(ow)]; S_5[n, c1, kh, oh, ow, c0] -> [(1 + ow)]; S_6[n, c1, kh, kw, oh, ow, c0] -> +[(ow)] }, del { S_4[n, c1, kh, oh, c0] -> [(c0)]; S_3[n, c1, oh, ow, c0] -> [(c0)]; S_5[n, c1, kh, oh, ow, c0] -> +[(c0)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(c0)] }, | { S_4[n, c1, kh, oh, c0] -> [(kh)]; S_3[n, c1, oh, ow, c0] -> +[(0)]; S_5[n, c1, kh, oh, ow, c0] -> [(kh)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(kh)] }, v { S_4[n, c1, kh, oh, c0] -> +[(0)]; S_3[n, c1, oh, ow, c0] -> [(0)]; S_5[n, c1, kh, oh, ow, c0] -> [(0)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(-kw)] } +add { S_4[n, c1, kh, oh, c0] -> [(c0)]; S_3[n, c1, oh, ow, c0] -> [(c0)]; S_5[n, c1, kh, oh, ow, c0] -> [(c0)]; S_6[n, +c1, kh, kw, oh, ow, c0] -> [(c0)] }, +] + * This strategy is designed for Davinci architecture, for its five dimension data format. + * We suppose two steps to achieve this strategy: + * 1. find the last axis C0 schedule in the multi_union_pw_aff + * 2. if find this schedule, move it to the end of the multi_union_pw_aff + * 3. 
add the updated multi_union_pw_aff to the band node + * *****************************************************/ +isl::schedule_node SinkC0::SinkC0Schedule(isl::schedule_node &node) { + if (!node.isa()) { + return node; + } + auto schedule = node.as().get_partial_schedule(); + isl::union_pw_aff_list upal = isl::union_pw_aff_list(); + std::unordered_map lastIdxSchedule; + + // make new union pw aff list + for (unsigned int i = 0; i < schedule.size(); ++i) { + isl::union_pw_aff upa = schedule.get_union_pw_aff(i); + isl::pw_aff_list paList = upa.get_pw_aff_list(); + bool findC0Schedule = FindC0Schedule(paList); + if (findC0Schedule) { + lastIdxSchedule[i] = true; + continue; + } + if (upal.is_null()) { + upal = isl::union_pw_aff_list(upa); + } else { + upal = upal.add(upa); + } + } + + // save permutable value for this band + int permutable = node.as().get_permutable(); + if (!lastIdxSchedule.empty() && permutable == 1) { + for (auto idx : lastIdxSchedule) { + isl::union_pw_aff upa = schedule.get_union_pw_aff(idx.first); + if (upal.is_null()) { + upal = isl::union_pw_aff_list(upa); + } else { + upal = upal.add(upa); + } + } + } else { + return node; + } + + std::vector coincident; + int n = node.as().n_member(); + ExchangeCoincident(coincident, node, lastIdxSchedule, n); + + // make multi_union_pw_aff + isl::multi_union_pw_aff mupa = isl::multi_union_pw_aff(schedule.get_space(), upal); + + // delete old node + node = node.del(); + + // insert new node + node = node.insert_partial_schedule(mupa); + node = node.as().set_permutable(permutable); + for (int i = 0; i < n; ++i) { + node = node.as().member_set_coincident(i, coincident[i]); + } + return node; +} + +isl::schedule SinkC0::Run(isl::schedule sch) { + auto fn = [&, this](isl::schedule_node node) -> isl::schedule_node { + if (node.isa()) { + node = SinkC0Schedule(node); + } + return node; + }; + + return sch.get_root().map_descendant_bottom_up(fn).get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/sink_c0.h b/src/poly/schedule_pass/sink_c0.h new file mode 100644 index 0000000000000000000000000000000000000000..e9cae9119ebddb8f6e7c39538d0a53007ba29f94 --- /dev/null +++ b/src/poly/schedule_pass/sink_c0.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_SINK_C0_H_ +#define POLY_SINK_C0_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * For each band node in schedule tree, get multi_union_pw_aff from the current band node. Then, move the last axis C0 + * schedule to the end of this multi_union_pw_aff and add the updated multi_union_pw_aff to the current band node. 
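+ *
+ * Schematically, assuming a permutable band whose members are [n, c1, c0, kh]:
+ *   before: partial schedule = [(n), (c1), (c0), (kh)]
+ *   after : partial schedule = [(n), (c1), (kh), (c0)]
+ * The permutable property is kept and the coincident flags are reordered
+ * together with the members (see ExchangeCoincident); bands that are not
+ * permutable, or that contain no C0 schedule, are left unchanged.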
+ */ +class SinkC0 : public SchedulePass { + public: + SinkC0() { pass_name_ = __FUNCTION__; } + ~SinkC0() {} + + virtual isl::schedule Run(isl::schedule sch); + + private: + bool FindC0Schedule(const isl::pw_aff_list &paList); + void ExchangeCoincident(std::vector &coincident, const isl::schedule_node &node, + const std::unordered_map lastIdxSchedule, const int &n); + isl::schedule_node SinkC0Schedule(isl::schedule_node &node); +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_SINK_C0_H_ \ No newline at end of file diff --git a/src/poly/sink_axis.cc b/src/poly/schedule_pass/sink_last_axis.cc similarity index 51% rename from src/poly/sink_axis.cc rename to src/poly/schedule_pass/sink_last_axis.cc index 668363977b7ca1bb68fc7a5a912c1217b672b019..1c47150c1522fa1533d684f9c30765cf7e3aa6be 100644 --- a/src/poly/sink_axis.cc +++ b/src/poly/schedule_pass/sink_last_axis.cc @@ -13,198 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "sink_last_axis.h" -#include "poly/sink_axis.h" - -#include -#include #include -#include -#include -#include -#include - -#include "poly/dump_log.h" - namespace akg { namespace ir { namespace poly { -bool FindC0Schedule(const isl::pw_aff_list &paList) { - for (unsigned int upaIdx = 0; upaIdx < paList.size(); ++upaIdx) { - isl::pw_aff pa = paList.get_at(upaIdx); - int64_t inDimSize = isl_pw_aff_dim(pa.get(), isl_dim_in); - CHECK_NE(inDimSize, -1); - const char *lastInDim = isl_pw_aff_get_dim_name(pa.get(), isl_dim_in, inDimSize - 1); - if (lastInDim == nullptr) { - continue; - } - std::string lastAxis = lastInDim; - // pw_aff { S_4[n, c1, kh, oh, c0] -> [(n)] } - // to do use isl api to mark schedule axis - std::string pwAffStr = pa.to_str(); - std::size_t arrowPos = pwAffStr.find("->"); - if (arrowPos == std::string::npos) { - continue; - } - std::string rang = pwAffStr.substr(arrowPos + 2, pwAffStr.size() - (arrowPos + 2)); - std::size_t leftBracket = rang.find("("); - std::size_t rightBracket = rang.find(")"); - if ((leftBracket == std::string::npos) || (rightBracket == std::string::npos) || - (rightBracket <= leftBracket + 1)) { - continue; - } - std::string scheduleAxis = rang.substr(leftBracket + 1, rightBracket - leftBracket - 1); - if (lastAxis == scheduleAxis) { - // lastIdxSchedule[i] = true; - // findC0Schedule = true; - // break; - return true; - } - } - return false; -} - -void ExchangeCoincident(std::vector &coincident, const isl::schedule_node &node, - const std::unordered_map lastIdxSchedule, const int &n) { - // save coincident value for this band - std::vector coincidentOld; - for (int i = 0; i < n; ++i) { - coincidentOld.push_back(node.as().member_get_coincident(i)); - } - - // exchange last axis coincident to last position - for (int i = 0; i < n; ++i) { - if (lastIdxSchedule.count(i) > 0) { - continue; - } - coincident.push_back(coincidentOld[i]); - } - - for (auto item : lastIdxSchedule) { - CHECK_GE(item.first, 0) << "index of coincident can not be negative: " << item.first; - coincident.push_back(coincidentOld[item.first]); - } -} - -/* ***************************************************** - * Initialization part: - * get partial_schedule info and union_pw_aff_list from band node - * partial_schedule is a multi_union_pw_aff as follows: - * [ - { S_4[n, c1, kh, oh, c0] -> [(n)]; S_3[n, c1, oh, ow, c0] -> [(n)]; S_5[n, c1, kh, oh, ow, c0] -> [(n)]; S_6[n, -c1, kh, kw, oh, ow, c0] -> [(n)] }, { S_4[n, c1, kh, oh, c0] -> [(c1)]; S_3[n, c1, oh, 
ow, c0] -> [(c1)]; S_5[n, c1, kh, -oh, ow, c0] -> [(c1)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(c1)] }, { S_4[n, c1, kh, oh, c0] -> [(oh)]; S_3[n, c1, oh, -ow, c0] -> [(oh)]; S_5[n, c1, kh, oh, ow, c0] -> [(oh)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(oh)] }, { S_4[n, c1, kh, -oh, c0] -> [(0)]; S_3[n, c1, oh, ow, c0] -> [(ow)]; S_5[n, c1, kh, oh, ow, c0] -> [(1 + ow)]; S_6[n, c1, kh, kw, oh, ow, -c0] -> [(ow)] }, { S_4[n, c1, kh, oh, c0] -> [(c0)]; S_3[n, c1, oh, ow, c0] -> [(c0)]; S_5[n, c1, kh, oh, ow, c0] -> -[(c0)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(c0)] }, { S_4[n, c1, kh, oh, c0] -> [(kh)]; S_3[n, c1, oh, ow, c0] -> [(0)]; -S_5[n, c1, kh, oh, ow, c0] -> [(kh)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(kh)] }, { S_4[n, c1, kh, oh, c0] -> [(0)]; -S_3[n, c1, oh, ow, c0] -> [(0)]; S_5[n, c1, kh, oh, ow, c0] -> [(0)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(-kw)] } - ] - * Is union_pw_aff_list(upal) the other form of multi_union_pw_aff ? and it can not print in LOG(INFO) - * but we need it during update, at least we make a new multi_union_pw_aff from union_pw_aff_list - * and add it to the band node, shown in the following pseudo-code - * isl::union_pw_aff_list upal = isl::union_pw_aff_list(); - * ... ... - * update strategy of upal ... - * ... ... - * isl::multi_union_pw_aff mupa = isl::multi_union_pw_aff(partial_schedule.get_space(), upal); - * node = node.del(); - * node = node.insert_partial_schedule(mupa); - * - * The update strategy of SinkC0 is moving the schedule of axis of C0 with every statement - * to the end of the multi_union_pw_aff, the purpose result is shown in the following: - * -[ -{ S_4[n, c1, kh, oh, c0] -> [(n)]; S_3[n, c1, oh, ow, c0] -> [(n)]; S_5[n, c1, kh, oh, ow, c0] -> [(n)]; S_6[n, c1, kh, -kw, oh, ow, c0] -> [(n)] }, { S_4[n, c1, kh, oh, c0] -> [(c1)]; S_3[n, c1, oh, ow, c0] -> [(c1)]; S_5[n, c1, kh, oh, ow, -c0] -> [(c1)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(c1)] }, { S_4[n, c1, kh, oh, c0] -> [(oh)]; S_3[n, c1, oh, ow, c0] -> -[(oh)]; S_5[n, c1, kh, oh, ow, c0] -> [(oh)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(oh)] }, { S_4[n, c1, kh, oh, c0] -> -[(0)]; S_3[n, c1, oh, ow, c0] -> [(ow)]; S_5[n, c1, kh, oh, ow, c0] -> [(1 + ow)]; S_6[n, c1, kh, kw, oh, ow, c0] -> -[(ow)] }, del { S_4[n, c1, kh, oh, c0] -> [(c0)]; S_3[n, c1, oh, ow, c0] -> [(c0)]; S_5[n, c1, kh, oh, ow, c0] -> -[(c0)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(c0)] }, | { S_4[n, c1, kh, oh, c0] -> [(kh)]; S_3[n, c1, oh, ow, c0] -> -[(0)]; S_5[n, c1, kh, oh, ow, c0] -> [(kh)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(kh)] }, v { S_4[n, c1, kh, oh, c0] -> -[(0)]; S_3[n, c1, oh, ow, c0] -> [(0)]; S_5[n, c1, kh, oh, ow, c0] -> [(0)]; S_6[n, c1, kh, kw, oh, ow, c0] -> [(-kw)] } -add { S_4[n, c1, kh, oh, c0] -> [(c0)]; S_3[n, c1, oh, ow, c0] -> [(c0)]; S_5[n, c1, kh, oh, ow, c0] -> [(c0)]; S_6[n, -c1, kh, kw, oh, ow, c0] -> [(c0)] }, -] - * This strategy is designed for Davinci architecture, for its five dimension data format. - * We suppose two steps to achieve this strategy: - * 1. find the last axis C0 schedule in the multi_union_pw_aff - * 2. if find this schedule, move it to the end of the multi_union_pw_aff - * 3. 
add the updated multi_union_pw_aff to the band node - * *****************************************************/ -isl::schedule_node Transform::SinkC0Schedule(isl::schedule_node &node) { - if (!node.isa()) { - return node; - } - auto schedule = node.as().get_partial_schedule(); - isl::union_pw_aff_list upal = isl::union_pw_aff_list(); - std::unordered_map lastIdxSchedule; - - // make new union pw aff list - for (unsigned int i = 0; i < schedule.size(); ++i) { - isl::union_pw_aff upa = schedule.get_union_pw_aff(i); - isl::pw_aff_list paList = upa.get_pw_aff_list(); - bool findC0Schedule = FindC0Schedule(paList); - if (findC0Schedule) { - lastIdxSchedule[i] = true; - continue; - } - if (upal.is_null()) { - upal = isl::union_pw_aff_list(upa); - } else { - upal = upal.add(upa); - } - } - - // save permutable value for this band - int permutable = node.as().get_permutable(); - if (!lastIdxSchedule.empty() && permutable == 1) { - for (auto idx : lastIdxSchedule) { - isl::union_pw_aff upa = schedule.get_union_pw_aff(idx.first); - if (upal.is_null()) { - upal = isl::union_pw_aff_list(upa); - } else { - upal = upal.add(upa); - } - } - } else { - return node; - } - - std::vector coincident; - int n = node.as().n_member(); - ExchangeCoincident(coincident, node, lastIdxSchedule, n); - - // make multi_union_pw_aff - isl::multi_union_pw_aff mupa = isl::multi_union_pw_aff(schedule.get_space(), upal); - - // delete old node - node = node.del(); - - // insert new node - node = node.insert_partial_schedule(mupa); - node = node.as().set_permutable(permutable); - for (int i = 0; i < n; ++i) { - node = node.as().member_set_coincident(i, coincident[i]); - } - return node; -} - -isl::schedule Transform::SinkC0(const isl::schedule &sch) { - auto fn = [&, this](isl::schedule_node node) -> isl::schedule_node { - if (node.isa()) { - node = SinkC0Schedule(node); - } - return node; - }; - - return sch.get_root().map_descendant_bottom_up(fn).get_schedule(); -} - /* * Check whether the domain of the last axis is not larger than a threshold, * so that large last axes can still be tiled. @@ -385,25 +201,11 @@ static isl::schedule_node SinkLastAxisFromBand(const isl::schedule_node &outer_b return MoveLastAxisToInnermostBand(outer_band); } -/* - * Try to sink the last axis of outer band to the leaves of the schedule tree. - * - * The criteria that the last axis can be sinked: - * 1) the axis is the last axis in the outer band schedule. - * 2) the axis is the last axis in the domain of each statement. - * 3) all dependencies of the last axis are equality constraints. (i.e. S_1[c0] -> S_2[c0' = c0]) - * 4) all dependencies of the last axis do not appear in other non-last axes. - * 5) the domain of the last axis is not larger than a threshold (otherwise it still should be tiled). - * - * sinkLastAxis will: - * 1) remove the C0 axis from the outer band schedule, and - * 2) add a partial schedule (C0) to each leaf filter node that contains the last axis. 
- */ -isl::schedule Transform::SinkLastAxis(const isl::schedule &sch) { +isl::schedule SinkLastAxis::Run(isl::schedule sch) { auto outer_band_node = sch.get_root(); while (true) { if (outer_band_node.isa()) { - outer_band_node = SinkLastAxisFromBand(outer_band_node, dependences_); + outer_band_node = SinkLastAxisFromBand(outer_band_node, pass_info_.dependences_); break; } unsigned int n_children = outer_band_node.n_children(); @@ -417,7 +219,7 @@ isl::schedule Transform::SinkLastAxis(const isl::schedule &sch) { if (outer_band_node.child(i).n_children() == 0) continue; outer_band_node = outer_band_node.child(i).child(0); if (outer_band_node.isa()) { - outer_band_node = SinkLastAxisFromBand(outer_band_node, dependences_); + outer_band_node = SinkLastAxisFromBand(outer_band_node, pass_info_.dependences_); } outer_band_node = outer_band_node.parent().parent(); } diff --git a/src/poly/schedule_pass/sink_last_axis.h b/src/poly/schedule_pass/sink_last_axis.h new file mode 100644 index 0000000000000000000000000000000000000000..06e41620fc4def5cfe0f23cc78b012097ef43dbb --- /dev/null +++ b/src/poly/schedule_pass/sink_last_axis.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_SINK_LAST_AXIS_H_ +#define POLY_SINK_LAST_AXIS_H_ + +#include "poly/schedule_pass.h" +#include "poly/pass_info.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Try to sink the last axis of outer band to the leaves of the schedule tree. + * + * The criteria that the last axis can be sinked: + * 1) the axis is the last axis in the outer band schedule. + * 2) the axis is the last axis in the domain of each statement. + * 3) all dependencies of the last axis are equality constraints. (i.e. S_1[c0] -> S_2[c0' = c0]) + * 4) all dependencies of the last axis do not appear in other non-last axes. + * 5) the domain of the last axis is not larger than a threshold (otherwise it still should be tiled). + * + * SinkLastAxis will: + * 1) remove the C0 axis from the outer band schedule, and + * 2) add a partial schedule (C0) to each leaf filter node that contains the last axis. 
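+ *
+ * For example, a dependence S_1[i, c0] -> S_2[i' = i, c0' = c0] keeps c0
+ * sinkable (criterion 3), while S_1[i, c0] -> S_2[i' = i, c0' = c0 + 1]
+ * violates it, and the outer band is then left untouched.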
+ */ +class SinkLastAxis : public SchedulePass { + public: + SinkLastAxis(PassInfo &pass_info) : pass_info_(pass_info) { pass_name_ = __FUNCTION__; } + ~SinkLastAxis() {} + + virtual isl::schedule Run(isl::schedule sch); + + private: + PassInfo &pass_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_SINK_LAST_AXIS_H_ \ No newline at end of file diff --git a/src/poly/schedule_pass/split_outer_band.cc b/src/poly/schedule_pass/split_outer_band.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf50acfad5442ed682b21cecf8d4c2d4b9dca4b0 --- /dev/null +++ b/src/poly/schedule_pass/split_outer_band.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "split_outer_band.h" + +namespace akg { +namespace ir { +namespace poly { + +isl::schedule SplitOuterBand::Run(isl::schedule curr_schedule) { + isl::schedule_node node = curr_schedule.get_root(); + while (!node.isa()) { + node = node.child(0); + } + isl::schedule_node_band band = node.as(); + unsigned i = 0; + unsigned n = band.n_member(); + for (; i < n; ++i) { + if (!band.member_get_coincident(i)) { + break; + } + } + if ((n <= 1) || (i == 0) || (i == n)) { + return node.get_schedule(); + } + node = band.split(i); + return node.get_schedule(); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/split_outer_band.h b/src/poly/schedule_pass/split_outer_band.h new file mode 100644 index 0000000000000000000000000000000000000000..971914b285d989c2acd6bdbe8cfd78400fdf20c5 --- /dev/null +++ b/src/poly/schedule_pass/split_outer_band.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_SPLIT_OUTER_BAND_H_ +#define POLY_SPLIT_OUTER_BAND_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Split the consecutive parallelled nodes (i.e. coincident equals to 1) from the most-outer band, + * resulting in an outer band with an inner band containing all untileable nodes as its child. + * Note that this transfrom can prevent shift case when post-fusion exists in dynamic shape. 
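+ *
+ * For example, with coincident flags [1, 1, 0, 1] the outermost band is split
+ * after its first two members; no split is performed when the band has a
+ * single member, when its first member is not coincident, or when every
+ * member is coincident.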
+ */ +class SplitOuterBand : public SchedulePass { + public: + SplitOuterBand() { pass_name_ = __FUNCTION__; }; + ~SplitOuterBand(){}; + + virtual isl::schedule Run(isl::schedule sch); +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_SPLIT_OUTER_BAND_H_ diff --git a/src/poly/schedule_pass/tile_outer_band.cc b/src/poly/schedule_pass/tile_outer_band.cc new file mode 100644 index 0000000000000000000000000000000000000000..83fa9392efe4d046663730dc1ebf9f0cd0f1fd8c --- /dev/null +++ b/src/poly/schedule_pass/tile_outer_band.cc @@ -0,0 +1,937 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tile_outer_band.h" + +#include "poly/scop.h" +#include "poly/schedule_pass/transfer_stmt.h" +#include "poly/schedule_pass/try_mark_scalar_stmt.h" + +namespace akg { +namespace ir { +namespace poly { + +class DimInfoMatcher : public IRVisitor { + public: + DimInfoMatcher() = default; + ~DimInfoMatcher() override = default; + + std::string dim() { return dim_; } + + void Visit_(const AttrStmt *op) final { + if (const auto Cop = op->node.as()) { + for (auto iter : Cop->attrs) { + if (dim_.empty() && iter.first == "dim") { + if (auto dim = iter.second.as()) { + dim_ = dim->value; + break; + } + } + } + } + } + + private: + std::string dim_ = ""; +}; + +std::string TileOuterBand::GetcDim() { + auto matcher = DimInfoMatcher(); + matcher.Visit(scop_info_.user_config_.GetBody()); + return matcher.dim(); +} + +// Init set_dim info +void TileOuterBand::InitDimensionInfo(const isl::schedule &sch_init) { + // get compute dim + std::string dim = GetcDim(); + // get build dim + if (dim.empty()) { + dim = GetbDim(); + } + + // apply default tiling + if (dim.empty()) { + auto tiling_res = GenerateTiling(sch_init, scop_info_, GenHalide(scop_info_, sch_init, true)); + scop_info_.analysis_result_.SetTileSizes(tiling_res.first); + scop_info_.analysis_result_.SetTileConstraints(tiling_res.second); + if (scop_info_.cube_info_.IsConv()) scop_info_.cube_info_.SetConvMNKInfo(); + return; + } + + const std::string pattern = " "; + std::vector str = Split(dim, pattern); + const int dim_info_entry_size = 4; + CHECK(!str.empty() && !(str.size() % dim_info_entry_size)) << "Error: You need to set dim !"; + int sequence = 0; + for (size_t i = 0; i < str.size(); i += dim_info_entry_size) { + DimensionInfo dim_info; + char *endptr = nullptr; + const int radix = 10; + dim_info.index = strtol(str[i].c_str(), &endptr, radix); + if (endptr == nullptr || *endptr != '\0') LOG(FATAL) << "failed to convert string " << str[i] << " to number"; + const int max_dim_index = 16; + CHECK(dim_info.index < max_dim_index) << "set_dim index must be less than " << max_dim_index << "!"; + dim_info.axis = str[i + 1]; + const int default_tiling_size = 65535; + endptr = nullptr; + int64_t str_2_number = strtol(str[i + 2].c_str(), &endptr, radix); + if (endptr == nullptr || *endptr != '\0' || str_2_number <= 0) { + dim_info.l1_tiling_size = 
default_tiling_size; + } else { + dim_info.l1_tiling_size = str_2_number; + } + endptr = nullptr; + int64_t str_3_number = strtol(str[i + 3].c_str(), &endptr, radix); + if (endptr == nullptr || *endptr != '\0' || str_3_number <= 0) { + dim_info.l0_tiling_size = default_tiling_size; + } else { + dim_info.l0_tiling_size = str_3_number; + } + dim_info.dim_seq = sequence; + sequence++; + scop_info_.analysis_result_.InsertDimensionInfo(dim_info); + } +} + +void TileOuterBand::MergeTilingInfo() { + int64_t tiles_num = 0; + auto tile_sizes = scop_info_.analysis_result_.GetTileSizes(); + for (unsigned i = 0; i < tile_sizes.size(); ++i) { + if (tiles_num <= tile_sizes[i].index) { + tiles_num = tile_sizes[i].index + 1; + } + } + tiles_.resize((size_t)tiles_num); + + for (unsigned i = 0; i < tile_sizes.size(); ++i) { + tiles_[(unsigned int)tile_sizes[i].index].dim_infos.push_back(tile_sizes[i]); + } +} + +std::vector> TileOuterBand::AddTileInfo(const std::vector> &partition_info) { + std::vector> info; + PartitionSingle *single = PartitionSingle::getInstance(); + if (single == nullptr) { + return partition_info; + } else if (PartitionSingle::getTimes() < 2) { + // first time gemm or m isolate main gemm + return partition_info; + } + + for (auto it : partition_info) { + info.push_back(it); + } + return info; +} + +isl::schedule TileOuterBand::Run(isl::schedule sch) { + auto map_before_tile = sch.get_map(); + // TransferStmt pass + isl::schedule tiling_schedule = sch; + if (!scop_info_.cube_info_.IsSpecGemm()) { + tiling_schedule = TransferStmt(scop_info_, pass_info_).Run(tiling_schedule); + } + scop_info_.analysis_result_.InitScheduleMapBeforeTile(scop_info_.GetCtx()); + if (!scop_info_.cube_info_.IsSpecGemm() && (scop_info_.cube_info_.IsConv() || scop_info_.cube_info_.IsGemm())) { + scop_info_.analysis_result_.SetScheduleMapBeforeTile(sch.get_map()); + } + InitDimensionInfo(tiling_schedule); + MergeTilingInfo(); + + isl::schedule_node root = sch.get_root(); + + // 1. obtain the outermost tilable band + isl::schedule_node node = GetOuterBand(root); + + ShowDimInfo(); + + // 2. Traverse the descendants of "node" (including the node itself) + // in depth first postorder via the callback function. + using std::placeholders::_1; + const std::function f = + std::bind(&TileOuterBand::MarkOuterPermutable, this, _1); + + if (node.isa()) { + tile_sizes_ = tiles_[0].dim_infos; + node = node.map_descendant_bottom_up(f); + } else { + // multiple outer bands, use same filter strategy as in auto tiling + unsigned int band_idx = 0; + for (auto i = 0; i < static_cast(node.n_children()); ++i) { + tile_sizes_ = band_idx < tiles_.size() ? 
tiles_[band_idx].dim_infos : tiles_[0].dim_infos; + if (node.get_child(i).isa()) { + auto filter = node.get_child(i).as(); + if (!filter.get_filter().is_empty() && filter.has_children() && + filter.get_child(0).isa()) { + band_idx += 1; + } + } + node = node.child(i).map_descendant_bottom_up(f); + node = node.parent(); + } + } + scop_info_.AddPartitionInfoToData(AddTileInfo(partition_info_)); + scop_info_.analysis_result_.SetIsTiled(true); + + auto final_schedule = node.get_schedule(); + if (final_schedule.get_map().is_equal(map_before_tile) && + (pass_info_.coincident_ || scop_info_.user_config_.GetConsiderCoincidence())) { + restart_ = true; + } else if (sch.plain_is_equal(final_schedule)) { + pass_info_.tile_check_coincident_ = scop_info_.user_config_.GetTileCheckCoincident(); + final_schedule = TryMarkScalarStmt(pass_info_).Run(final_schedule); + } + + return final_schedule; +} + +void TileOuterBand::ShowDimInfo() { + for (size_t i = 0; i < tiles_.size(); ++i) { + LOG(INFO) << "No: " << i << ", tiling_flag: " << tiles_[i].tiling_flag; + + for (const auto &dim_info : tiles_[i].dim_infos) { + std::stringstream ss; + ss << "index: " << dim_info.index << ", axis: " << dim_info.axis << ", l1_size: " << dim_info.l1_tiling_size + << ", l0_size: " << dim_info.l0_tiling_size << ", seq: " << dim_info.dim_seq + << ", is inner: " << dim_info.is_inner; + if (dim_info.l1_var.defined()) ss << ", l1_var: " << dim_info.l1_var; + if (dim_info.l0_var.defined()) ss << ", l0_var: " << dim_info.l0_var; + LOG(INFO) << ss.str(); + } + } +} + +bool TileOuterBand::IsPermutable(const isl::schedule_node &node, bool checkCoincident) { + if (!node) return false; + if (!node.isa()) return false; + if (!node.as().get_permutable()) return false; + if (node.as().n_member() < 1) return false; + return !(checkCoincident && !node.as().member_get_coincident(0)); +} + +isl::schedule_node TileOuterBand::InsertEmptyPermutableBand(isl::schedule_node node) { + isl::space space; + isl::multi_union_pw_aff mupa; + + space = node.get_schedule().get_domain().get_space(); + + space = space.set_from_params(); + mupa = isl::multi_union_pw_aff::zero(space); + node = node.insert_partial_schedule(mupa); + node = node.as().set_permutable(1); + + return node; +} + +bool TileOuterBand::SubtreeHasPermutableBands(const isl::schedule_node &node) { + bool all_non_permutable = false; + all_non_permutable = node.every_descendant([&, this](const isl::schedule_node &node) -> bool { + return BoolNot(IsPermutable(node, scop_info_.user_config_.GetTileCheckCoincident())); + }); + + return BoolNot(all_non_permutable); +} + +int TileOuterBand::IsCandidate(const isl::schedule_node &node) { + int permutable; + + if (node.isa()) return 1; + permutable = static_cast(IsPermutable(node, scop_info_.user_config_.GetTileCheckCoincident())); + if (permutable) return permutable; + if (node.isa()) return 0; + permutable = static_cast(SubtreeHasPermutableBands(node)); + if (permutable < 0) return -1; + return static_cast(!permutable); +} + +int TileOuterBand::IsOuterTilable(const isl::schedule_node &node) { + int tilable; + isl::schedule_node ancestor; + + tilable = IsCandidate(node); + if (tilable < 0) return -1; + if (!tilable) return 0; + + tilable = 0; + ancestor = node; + while (ancestor.has_parent()) { + ancestor = ancestor.parent(); + + tilable = IsCandidate(ancestor); + if (tilable) break; + } + + return static_cast(BoolNot(static_cast(tilable))); +} + +isl::schedule_node TileOuterBand::MarkTileBand(isl::schedule_node node, TileType tile_type) { + std::string 
markTag; + + if (tile_type == TileType::L0) { + markTag = REALIZE_L0; + node = node.insert_mark(isl::id(node.ctx(), markTag)); +#if SPEC_GEMM + if (scop_info_.cube_info_.IsConv()) { + std::string mark_tag_gmm = CONV_GEMM; + node = node.insert_mark(isl::id(node.ctx(), mark_tag_gmm)); + } +#endif + } + if (tile_type == TileType::L1) { + markTag = REALIZE_L1; + node = node.insert_mark(isl::id(node.ctx(), markTag)); + } + if (tile_type == TileType::UB) { + markTag = REALIZE_UB; + node = node.insert_mark(isl::id(node.ctx(), markTag)); + } + if (tile_type == TileType::UBL0) { + markTag = REALIZE_UBL0; + node = node.insert_mark(isl::id(node.ctx(), markTag)); + } + if (tile_type == TileType::UBL1) { + markTag = REALIZE_UBL1; + node = node.insert_mark(isl::id(node.ctx(), markTag)); + } + if (tile_type == TileType::L1UBL1) { + markTag = REALIZE_L1UBL1; + node = node.insert_mark(isl::id(node.ctx(), markTag)); + } + + return node; +} + +isl::multi_val TileOuterBand::MultiValFromIntList(const isl::space &space, int dim, const int *list) { + int i; + isl::multi_val mv; + + isl::ctx ctx = space.ctx(); + mv = isl::multi_val::zero(space); + for (i = 0; i < dim; ++i) { + mv = mv.set_val(i, isl::val(ctx, list[i])); + } + + return mv; +} + +/* Build tile map which maps the elements of the original band + * to applied tile, with the form: + * [[outer] -> [orig]] -> [[outer] -> [tile]]. + */ +isl::map TileOuterBand::ComputeTileMap(const isl::schedule_node &original_node, const isl::schedule_node &tiled_node) { + isl::union_map original_umap = original_node.as().get_partial_schedule_union_map(); + unsigned int depth = original_node.get_schedule_depth(); + + isl::space space = original_umap.get_space().params().set_from_params(); + space = space.add_dims(isl_dim_set, depth); + space = space.map_from_set(); + + isl::multi_aff maff = isl::multi_aff::identity(space); + isl::union_map tiled_umap = tiled_node.as().get_partial_schedule_union_map(); + tiled_umap = original_umap.reverse().apply_range(tiled_umap); + isl::multi_union_pw_aff tiling = isl::multi_union_pw_aff::from_union_map(tiled_umap); + + isl::map el2tile = isl::map::from(isl::union_map::from(tiling)); + el2tile = isl::map::from(isl::union_map(isl::map::from(maff)).product(el2tile)); + + return el2tile; +} + +/* + * Compute full tiles + */ +std::pair TileOuterBand::ComputeFullTile(const isl::schedule_node &original_node, + const isl::schedule_node &tiled_node) { + isl::map el2tile = ComputeTileMap(original_node, tiled_node); + isl::map tile2el = el2tile.reverse(); + + isl::union_map prefix = original_node.as().get_prefix_schedule_union_map(); + isl::union_set domain = original_node.as().get_domain(); + isl::union_map original_schedule = original_node.as().get_partial_schedule_union_map(); + isl::multi_union_pw_aff mupa = isl::multi_union_pw_aff::from_union_map(original_schedule); + + isl::union_map schedule = isl::union_map::from(mupa); + schedule = prefix.range_product(schedule); + + isl::set all_el = isl::set::from_union_set(domain.apply(schedule)); + all_el = all_el.coalesce(); + + isl::set all = all_el.apply(el2tile); + + isl::set partial = all.apply(tile2el); + partial = partial.subtract(all_el); + partial = partial.apply(el2tile); + + return {all.subtract(partial), all}; +} + +void TileOuterBand::IsolateLevelInfo(TileType &tile_type, isl::set &tiles, isl::set &all) { + // which level do we need isolate info? 
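+  // Partition info is only collected at the L1/UB level. For every tiled
+  // dimension the loop below records an increasing list of cut points:
+  // 0, the lexmin of the full tiles, the lexmax of the full tiles + 1, and
+  // the lexmax of all tiles + 1 (non-increasing edges are skipped). The
+  // result is later handed to AddTileInfo() / AddPartitionInfoToData() in Run().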
+ if (TileType::L1 == tile_type || TileType::UB == tile_type) { + partition_info_.clear(); + auto tiles_hull = tiles.simple_hull(); + auto tiles_lexmin = tiles_hull.lexmin().simple_hull(); + auto tiles_lexmax = tiles_hull.lexmax().simple_hull(); + auto all_lexmax = all.simple_hull().lexmax().simple_hull(); + for (int i = 0; i < static_cast(tiles.n_dim()); ++i) { + std::vector part; + partition_info_.push_back(part); + partition_info_[i].push_back(0); + + int edge = static_cast(tiles_lexmin.dim_max_val(i).get_num_si()); + if (edge > partition_info_[i].back()) partition_info_[i].push_back(edge); + + edge = static_cast(tiles_lexmax.dim_max_val(i).get_num_si()) + 1; + if (edge > partition_info_[i].back()) partition_info_[i].push_back(edge); + + edge = static_cast(all_lexmax.dim_max_val(i).get_num_si()) + 1; + if (edge > partition_info_[i].back()) partition_info_[i].push_back(edge); + } + } +} + +/* + * Set the non-isolated loop type to the isolated part. + */ +isl::schedule_node TileOuterBand::SetIsolateLoopType(isl::schedule_node node) { + int i, n; + + if (!node.isa()) return node; + + n = static_cast(node.as().n_member()); + for (i = 0; i < n; ++i) { + enum isl_ast_loop_type type; + + type = isl_schedule_node_band_member_get_ast_loop_type(node.get(), i); + if (type == isl_ast_loop_default) node = node.as().member_set_isolate_ast_loop_default(i); + if (type == isl_ast_loop_atomic) node = node.as().member_set_isolate_ast_loop_atomic(i); + if (type == isl_ast_loop_unroll) node = node.as().member_set_isolate_ast_loop_unroll(i); + if (type == isl_ast_loop_separate) + node = node.as().member_set_isolate_ast_loop_separate(i); + else + return node; + } + + return node; +} + +/* Isolate tiles on demand. + */ +isl::schedule_node TileOuterBand::IsolateTiles(const isl::schedule_node &original_node, isl::schedule_node tiled_node, + TileType tile_type, const int *full_tile_min, const int *full_tile_max) { + CHECK(tiled_node.isa()); + int in, depth, dim; + isl::space space; + isl::set tiles, all; + isl::map map; + isl::set set; + isl::union_set opt; + isl::multi_aff ma1, ma2; + + // If not tiled, return + if (original_node.is_equal(tiled_node)) return tiled_node; + + depth = tiled_node.get_schedule_depth(); + dim = static_cast(tiled_node.as().n_member()); + + // compute a set "tiles" for all full tiles + std::tie(tiles, all) = ComputeFullTile(original_node, tiled_node); + if (nullptr != full_tile_min) { + unsigned int n_dim = tiles.n_dim(); + for (int i = 0; i < dim; ++i) { + if (0 == full_tile_min[i]) continue; + tiles = isl::manage( + isl_set_lower_bound_si(tiles.copy(), isl_dim_set, (n_dim - (unsigned int)(dim - i)), full_tile_min[i])); + } + } + if (nullptr != full_tile_max) { + unsigned int n_dim = tiles.n_dim(); + for (int i = 0; i < dim; ++i) { + if (MAX_STRIDE == full_tile_max[i]) continue; + tiles = isl::manage( + isl_set_upper_bound_si(tiles.copy(), isl_dim_set, (n_dim - (unsigned int)(dim - i)), full_tile_max[i])); + } + } + + IsolateLevelInfo(tile_type, tiles, all); + + map = tiles.unwrap(); + in = static_cast(map.dim(isl_dim_in)); + auto out = map.dim(isl_dim_out); + + auto upos = static_cast(depth - in); + auto udim = static_cast(dim); + map = map.project_out(isl_dim_out, (upos + udim), out - (upos + udim)); + + space = map.get_space().range(); + + ma1 = isl::multi_aff::project_out_map(space, isl_dim_set, upos, udim); + ma2 = isl::multi_aff::project_out_map(space, isl_dim_set, 0, upos); + ma1 = ma1.range_product(ma2); + + map = map.apply_range(isl::map(ma1)); + map = map.uncurry(); + map = 
map.flatten_domain(); + + set = map.wrap(); + set = set.set_tuple_name("isolate"); + + opt = tiled_node.as().get_ast_build_options(); + opt = opt.add_set(set); + tiled_node = tiled_node.as().set_ast_build_options(opt); + tiled_node = SetIsolateLoopType(tiled_node); + + return tiled_node; +} + +isl::multi_val TileOuterBand::ComputeBandTilesSizes(const isl::schedule_node &node, const int *tile_size) { + isl::space space; + + space = node.as().get_space(); + auto dim = static_cast(node.as().n_member()); + return MultiValFromIntList(space, dim, tile_size); +} + +isl::schedule_node TileOuterBand::TileBand(isl::schedule_node node, const isl::multi_val &sizes, TileType tile_type, + const int *full_tile_min, const int *full_tile_max, bool isolation) { + isl::ctx ctx = node.ctx(); + int scale_tile; + int shift_point; + + if (!node.isa()) { + return node; + } + scale_tile = isl_options_get_tile_scale_tile_loops(ctx.get()); + isl_stat status = isl_options_set_tile_scale_tile_loops(ctx.get(), 0); + CHECK(status == isl_stat_ok); + shift_point = isl_options_get_tile_shift_point_loops(ctx.get()); + status = isl_options_set_tile_shift_point_loops(ctx.get(), 1); + CHECK(status == isl_stat_ok); + + isl::schedule_node before_tile = node; + node = node.as().tile(sizes); + + if (!scop_info_.user_config_.GetIsDynamic() || scop_info_.cube_info_.IsSpecGemm()) { + if ((!scop_info_.user_config_.GetTileSizeIsVar()) && (isolation)) { + node = IsolateTiles(before_tile, node, tile_type, full_tile_min, full_tile_max); + } + } + + status = isl_options_set_tile_scale_tile_loops(ctx.get(), scale_tile); + CHECK(status == isl_stat_ok); + status = isl_options_set_tile_shift_point_loops(ctx.get(), shift_point); + CHECK(status == isl_stat_ok); + return node; +} + +void TileOuterBand::TileTypeL0(isl::schedule_node &node, int *full_tile_min, int *full_tile_max, TileType &tile_type, + bool &isolate, isl::multi_val &sizes) { + isl::set_list domain_list = node.get_domain().get_set_list(); + isl::union_set filter_cube = isl::union_set(); + isl::union_set filter_after_cube = isl::union_set(); + + unsigned int cube_index = 0; + for (; cube_index < scop_info_.analysis_result_.stmt_type_.size() - 1; ++cube_index) { + if (scop_info_.analysis_result_.stmt_type_[cube_index].second == STMT_OP_TYPE::CUBE_CONV || + scop_info_.analysis_result_.stmt_type_[cube_index].second == STMT_OP_TYPE::CUBE_GEMM || + scop_info_.analysis_result_.stmt_type_[cube_index].second == STMT_OP_TYPE::IM2COL_UB) { + break; + } + } + std::vector filter_before_cube; + + for (unsigned int set_index = 0; set_index < domain_list.size(); ++set_index) { + isl::set set_i = domain_list.get_at(set_index); + std::string name = set_i.get_tuple_name(); + CHECK(name.find('_') != std::string::npos) << "invalid name " << name; + unsigned int index = WrappedStrtol(name.substr(name.find('_') + 1)); + set_i = isl::manage(isl_set_eliminate_dims(set_i.copy(), 0, isl_set_n_dim(set_i.get()))); + if (index + 1 < cube_index) { + filter_before_cube.resize(cube_index - 1); + filter_before_cube[index] = isl::union_set(set_i); + } + if (index + 1 == cube_index || index == cube_index) { + filter_cube = filter_cube.is_null() ? isl::union_set(set_i) : filter_cube.add_set(set_i); + } + if (index > cube_index) { + filter_after_cube = filter_after_cube.is_null() ? 
isl::union_set(set_i) : filter_after_cube.add_set(set_i); + } + } + + isl::union_set_list filters = + isl::union_set_list(node.ctx(), static_cast(scop_info_.analysis_result_.stmt_type_.size() - 1)); + for (const auto &a : filter_before_cube) { + filters = a.is_null() ? filters : filters.add(a); + } + filters = filter_cube.is_null() ? filters : filters.add(filter_cube); + filters = filter_after_cube.is_null() ? filters : filters.add(filter_after_cube); + + if (scop_info_.cube_info_.IsLoad3dL1Ub()) { + node = TileBand(node, sizes, TileType::UB, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, TileType::UB); + } else if ((!filter_before_cube.empty() || !filter_after_cube.is_null()) && !filter_cube.is_null()) { + auto pos = 0; + node = node.insert_sequence(filters); + for (auto a : filter_before_cube) { + node = TileBand(node.child(pos).child(0), sizes, tile_type, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, TileType::UBL1); + node = node.parent().parent(); + ++pos; + } + if (!filter_cube.is_null()) { + node = TileBand(node.child(pos).child(0), sizes, tile_type, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, TileType::L0); + node = node.parent().parent(); + ++pos; + } + if (!filter_after_cube.is_null()) { + node = TileBand(node.child(pos).child(0), sizes, tile_type, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, TileType::UBL0); + node = node.parent().parent(); + ++pos; + } + } else { // Don't insert a sequence node when there is only one filter child + node = TileBand(node, sizes, tile_type, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, tile_type); + } + node = node.parent().parent(); +} + +isl::schedule_node TileOuterBand::TileL0(isl::schedule_node node) { + auto title_size = static_cast(tile_sizes_.size()); + const unsigned int n_member = node.child(0).as().n_member(); + unsigned int dim_num = (n_member <= title_size) ? 
n_member : title_size; + std::vector ts(n_member, 0); + std::vector full_tile_max(n_member, 0); + for (size_t j = 0; j < n_member; ++j) { + ts[j] = MAX_STRIDE; + full_tile_max[j] = MAX_STRIDE; + if (j < dim_num) { + ts[j] = static_cast(tile_sizes_[j].l0_tiling_size); + auto l1_tiling_size = static_cast(tile_sizes_[j].l1_tiling_size); + auto l0_tiling_size = static_cast(tile_sizes_[j].l0_tiling_size); + if (MAX_STRIDE == l1_tiling_size) continue; + if (MAX_STRIDE == l0_tiling_size) continue; + if ((l1_tiling_size > l0_tiling_size) && (0 != l0_tiling_size)) { + full_tile_max[j] = l1_tiling_size / l0_tiling_size - 1; + } + } + } + node = TileBandAndCollectMark(node.child(0), &ts[0], nullptr, &full_tile_max[0], TileType::L0, true); + return node; +} + +bool TileOuterBand::NeedIsolate() { return scop_info_.cube_info_.IsConv() || scop_info_.cube_info_.IsLoad3dL1Ub(); } + +void TileOuterBand::PaddingIsolate(int &h_head, int &h_tail, int &w_head, int &w_tail) { + h_head = 0; + h_tail = 0; + w_head = 0; + w_tail = 0; + if (scop_info_.cube_info_.GetConvAttrInfo().empty()) return; + int pad_top = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_PAD_TOP); + int pad_bottom = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_PAD_BOTTOM); + int pad_left = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_PAD_LEFT); + int pad_right = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_PAD_RIGHT); + int h = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_FEATURE_H); + int w = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_FEATURE_W); + int kh = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_KERNEL_H); + int kw = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_KERNEL_W); + int stride_h = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_STRIDE_H); + int stride_w = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_STRIDE_W); + int dilation_h = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_DILATION_H); + int dilation_w = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_DILATION_W); + int h_cut = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_TILE_H); + int w_cut = scop_info_.cube_info_.GetAttrValue(ATTR_CONV_TILE_W); + int d_kh = (kh - 1) * dilation_h + 1; + CHECK_NE(stride_h, 0); + int win_h = (h + pad_top + pad_bottom - d_kh) / stride_h + 1; + int win_cut_h = (h_cut - d_kh) / stride_h + 1; + if (win_cut_h > win_h) { + if (!scop_info_.user_config_.GetIsDynamic() || win_h > 0) win_cut_h = win_h; + } + + CHECK_NE(win_cut_h, 0); + int h_base = (win_h + win_cut_h - 1) / win_cut_h; + bool head = (pad_top > 0); + bool tail = ((win_h - 1) * stride_h + d_kh > h + pad_top); + + ComputeHInfo(h_base, head, tail, h_head, h_tail, win_h, win_cut_h); + + int d_kw = (kw - 1) * dilation_w + 1; + CHECK_NE(stride_w, 0); + int win_w = (w + pad_left + pad_right - d_kw) / stride_w + 1; + int win_cut_w = (w_cut - d_kw) / stride_w + 1; + if (win_cut_w > win_w) { + win_cut_w = win_w; + } + + CHECK_NE(win_cut_w, 0); + int w_base = (win_w + win_cut_w - 1) / win_cut_w; + head = (pad_left > 0); + tail = ((win_w - 1) * stride_w + d_kw > w + pad_right); + + ComputeWInfo(w_base, head, tail, w_head, w_tail, win_w, win_cut_w); +} + +void TileOuterBand::ComputeWInfo(int &w_base, bool &head, bool &tail, int &w_head, int &w_tail, int &win_w, + int &win_cut_w) { + const int DIVIDED_PIECES_THREE = 3; + const int DIVIDED_PIECES_TWO = 2; + CHECK_NE(win_cut_w, 0); + if (w_base >= DIVIDED_PIECES_THREE) { + if (head) { + w_head = 1; + if (tail) { + w_tail = w_base - DIVIDED_PIECES_TWO; + } else { + w_tail = win_w / win_cut_w - 1; + } + } else { + w_head = 0; + if (tail) { + w_tail = w_base - 
DIVIDED_PIECES_TWO; + } else { + w_tail = win_w / win_cut_w - 1; + } + } + } else if (w_base <= DIVIDED_PIECES_TWO) { + if (!head && !tail && win_w / win_cut_w == DIVIDED_PIECES_TWO) { + w_head = 0; + w_tail = 1; + } else if (head && !tail && win_w / win_cut_w == DIVIDED_PIECES_TWO) { + w_head = 1; + w_tail = 1; + } else { + w_head = 0; + w_tail = 0; + } + } +} + +void TileOuterBand::ComputeHInfo(int &h_base, bool &head, bool &tail, int &h_head, int &h_tail, int &win_h, + int &win_cut_h) { + const int DIVIDED_PIECES_THREE = 3; + const int DIVIDED_PIECES_TWO = 2; + CHECK_NE(win_cut_h, 0); + if (h_base >= DIVIDED_PIECES_THREE) { + if (head) { + h_head = 1; + if (tail) { + h_tail = h_base - DIVIDED_PIECES_TWO; + } else { + h_tail = win_h / win_cut_h - 1; + } + } else { + h_head = 0; + if (tail) { + h_tail = h_base - DIVIDED_PIECES_TWO; + } else { + h_tail = win_h / win_cut_h - 1; + } + } + } else if (h_base <= DIVIDED_PIECES_TWO) { + if (!head && !tail && win_h / win_cut_h == DIVIDED_PIECES_TWO) { + h_head = 0; + h_tail = 1; + } else if (head && !tail && win_h / win_cut_h == DIVIDED_PIECES_TWO) { + h_head = 1; + h_tail = 1; + } else { + h_head = 0; + h_tail = 0; + } + } +} + +void TileOuterBand::TileTypeL1(isl::schedule_node &node, int *full_tile_min, int *full_tile_max, TileType &tile_type, + bool &isolate, isl::multi_val &sizes) { + const unsigned int n_member = node.as().n_member(); + auto title_size = static_cast(tile_sizes_.size()); + unsigned int dim_num = (n_member <= title_size) ? n_member : title_size; + std::vector full_tile_max_buf(n_member, 0); + std::vector full_tile_min_buf(n_member, 0); + full_tile_max = &full_tile_max_buf[0]; + full_tile_min = &full_tile_min_buf[0]; + for (size_t j = 0; j < n_member; ++j) { + full_tile_min[j] = 0; + full_tile_max[j] = MAX_STRIDE; + if (!scop_info_.user_config_.GetIsDynamic()) { + if (NeedIsolate() && j < dim_num) { + int h_head, h_tail, w_head, w_tail; + PaddingIsolate(h_head, h_tail, w_head, w_tail); + + if (tile_sizes_[j].axis == "H") { + full_tile_min[j] = h_head; + full_tile_max[j] = h_tail; + } + + if (tile_sizes_[j].axis == "W") { + full_tile_min[j] = w_head; + full_tile_max[j] = w_tail; + } + } + } + } + node = TileBand(node, sizes, tile_type, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, tile_type); + + // L0 tiling + node = TileL0(node.child(0)); +} + +isl::schedule_node TileOuterBand::TileUbL1(isl::schedule_node node) { + const unsigned int n_member = node.child(0).as().n_member(); + unsigned int dim_num = (n_member <= static_cast(tile_sizes_.size())) + ? 
n_member + : static_cast(tile_sizes_.size()); + std::vector ts(n_member, 0); + std::vector full_tile_max(n_member, 0); + for (size_t j = 0; j < n_member; ++j) { + ts[j] = MAX_STRIDE; + full_tile_max[j] = MAX_STRIDE; + if (j < dim_num) { + ts[j] = static_cast(tile_sizes_[j].l0_tiling_size); + int l1_tiling_size = static_cast(tile_sizes_[j].l1_tiling_size); + int l0_tiling_size = static_cast(tile_sizes_[j].l0_tiling_size); + if (MAX_STRIDE == l1_tiling_size) continue; + if (MAX_STRIDE == l0_tiling_size) continue; + if ((l1_tiling_size > l0_tiling_size) && (0 != l0_tiling_size)) { + full_tile_max[j] = l1_tiling_size / l0_tiling_size - 1; + } + } + } + node = TileBandAndCollectMark(node.child(0), &ts[0], nullptr, &full_tile_max[0], TileType::UBL1, true); + return node; +} + +isl::schedule_node TileOuterBand::TileBandAndCollectMark(isl::schedule_node node, const int *tile_size, + int *full_tile_min, int *full_tile_max, TileType tile_type, + bool isolate) { + isl::multi_val sizes = ComputeBandTilesSizes(node, tile_size); + + if (tile_type == TileType::L1) { + TileTypeL1(node, full_tile_min, full_tile_max, tile_type, isolate, sizes); + } else if (tile_type == TileType::L0) { + TileTypeL0(node, full_tile_min, full_tile_max, tile_type, isolate, sizes); + } else if (tile_type == TileType::L1UBL1) { + node = TileBand(node, sizes, tile_type, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, tile_type); + node = TileUbL1(node.child(0)); + } else if (tile_type == TileType::UBL1) { + node = TileBand(node, sizes, tile_type, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, tile_type); + node = node.parent().parent(); + } else { + node = TileBand(node, sizes, tile_type, full_tile_min, full_tile_max, isolate); + node = MarkTileBand(node, tile_type); + } + return node; +} + +/*************************************************************************** + * steps: + * 1. get tile size. + * 2. tiling + ***************************************************************************/ +isl::schedule_node TileOuterBand::MarkOuterPermutable(isl::schedule_node node) { + // check tilable or not, and return the node if not + if (IsOuterTilable(node) <= 0) return node; + + // make sure the node is a band node and has multiple members, insert empty band if not + if (!node.isa() || (!node.as().member_get_coincident(0) && + scop_info_.user_config_.GetTileCheckCoincident())) + node = InsertEmptyPermutableBand(node); + +#if PRINT_SCHEDULE_INFO + /// print band info + isl::schedule_node_band outer_band = node.as(); + CHECK(!outer_band.is_null()) << " didn't find single outer_band \n" << pass_info_.schedule_; + LOG(INFO) << "Please set dim based on loops band depth: " << outer_band.n_member() << " with " + << outer_band.get_space(); + LOG(INFO) << "Domain info: " << outer_band; +#endif + + const unsigned int n_member = node.as().n_member(); + auto title_size = static_cast(tile_sizes_.size()); + unsigned int dim_num = (n_member <= title_size) ? 
n_member : title_size; + if (dim_num == 0) { + // direct scalar computation in GM is not allowed, need to promote to UB + return MarkTileBand(node, TileType::UB); + } + + // get tile size + std::vector tile_size(n_member, 0); + for (size_t j = 0; j < n_member; ++j) { + tile_size[j] = MAX_STRIDE; + // tile_size may be bigger than dim_num + if (j < dim_num) tile_size[j] = static_cast(tile_sizes_[j].l1_tiling_size); + } + + bool isCube = false; + for (auto &info : scop_info_.analysis_result_.GetStmtOpInfoMap()) { + if (info.second.isCube) { + isCube = true; + break; + } + } + + bool is_before_cube = false; + bool is_in_cube = false; + unsigned int i = 0; + for (; i < scop_info_.analysis_result_.stmt_type_.size() - 1; ++i) { + if (scop_info_.analysis_result_.stmt_type_[i].second == STMT_OP_TYPE::CUBE_CONV) { + break; + } + } + bool is_in_load3d = scop_info_.user_config_.GetIsDynamic() ? false : scop_info_.cube_info_.IsLoad3dL1Ub(); + isl::set_list domain_list = node.get_domain().get_set_list(); + for (unsigned int set_index = 0; set_index < domain_list.size(); ++set_index) { + isl::set set_i = domain_list.get_at(set_index); + std::string name = set_i.get_tuple_name(); + if (name.find('_') == std::string::npos) { + LOG(FATAL) << "Cannot find _ symbol"; + } + unsigned int index = WrappedStrtol(name.substr(name.find('_') + 1)); + is_before_cube = false; + if ((index + 1 < i) && !scop_info_.cube_info_.IsSpecGemm()) { + is_before_cube = true; + } + if (index + 1 == i) { + is_in_cube = true; + } + if (scop_info_.user_config_.GetIsDynamic()) { + if (scop_info_.cube_info_.IsLoad3dL1UBStmt(set_i.get_tuple_name())) { + is_in_load3d = true; + } + } + } + + if (isCube && is_before_cube && !is_in_cube) { + node = TileBandAndCollectMark(node, &tile_size[0], nullptr, nullptr, TileType::L1UBL1, true); + } else if (isCube || is_in_load3d) { + node = TileBandAndCollectMark(node, &tile_size[0], nullptr, nullptr, TileType::L1, true); + } else { + node = TileBandAndCollectMark(node, &tile_size[0], nullptr, nullptr, TileType::UB, true); + } + + return node; +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/tile_outer_band.h b/src/poly/schedule_pass/tile_outer_band.h new file mode 100644 index 0000000000000000000000000000000000000000..25b953430ecd18e2bbaba6e51815574f9ba4ba49 --- /dev/null +++ b/src/poly/schedule_pass/tile_outer_band.h @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef POLY_TILING_H_ +#define POLY_TILING_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +constexpr auto MAX_STRIDE = 65535; +/* + * Tile the outer band according to TilingInfo. In this pass, we get the outermost band, + * decide tile_size depending on the types of operators, and then start tiling.
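+ * + * A rough, hypothetical illustration (tile sizes invented for this comment, not taken from any real dim spec): + * for an outer band with two members, l1_tiling_size = {64, 32} and l0_tiling_size = {16, 16}, MarkOuterPermutable + * first tiles the band with the L1 sizes and marks it, and TileL0/TileUbL1 then tile the child band with the + * L0 sizes; members beyond tile_sizes_ keep MAX_STRIDE and are effectively left untiled.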
+ */ +class TileOuterBand : public SchedulePass { + public: + TileOuterBand(PassInfo &pass_info, ScopInfo &scop_info) : pass_info_(pass_info), scop_info_(scop_info) { + pass_name_ = __FUNCTION__; + }; + ~TileOuterBand() {} + + enum class TileType { + L0 = 0, + L1, + UB, + UBL1, + UBL0, + L1UBL1, + Invalid, + }; + virtual isl::schedule Run(isl::schedule sch); + void InitDimensionInfo(const isl::schedule &); + void MergeTilingInfo(); + std::vector> AddTileInfo(const std::vector> &partition_info); + std::string GetbDim() const { return scop_info_.user_config_.GetBDim(); } + std::string GetcDim(); + + void ShowDimInfo(); + isl::schedule_node MarkOuterPermutable(isl::schedule_node node); + int IsOuterTilable(const isl::schedule_node &node); + int IsCandidate(const isl::schedule_node &node); + bool IsPermutable(const isl::schedule_node &node, bool checkCoincident); + isl::schedule_node InsertEmptyPermutableBand(isl::schedule_node node); + bool SubtreeHasPermutableBands(const isl::schedule_node &node); + isl::schedule_node MarkTileBand(isl::schedule_node node, TileType tile_type); + isl::schedule_node TileBandAndCollectMark(isl::schedule_node node, const int *tile_size, int *full_tile_min, + int *full_tile_max, TileType tile_type, bool isolate); + isl::multi_val ComputeBandTilesSizes(const isl::schedule_node &node, const int *tile_size); + isl::multi_val MultiValFromIntList(const isl::space &space, int dim, const int *list); + void TileTypeL0(isl::schedule_node &node, int *full_tile_min, int *full_tile_max, TileType &tile_type, bool &isolate, + isl::multi_val &sizes); + isl::schedule_node TileBand(isl::schedule_node node, const isl::multi_val &sizes, TileType tile_type, + const int *full_tile_min, const int *full_tile_max, bool isolation); + isl::schedule_node IsolateTiles(const isl::schedule_node &original_node, isl::schedule_node tiled_node, + TileType tile_type, const int *full_tile_min, const int *full_tile_max); + std::pair ComputeFullTile(const isl::schedule_node &original_node, + const isl::schedule_node &tiled_node); + isl::map ComputeTileMap(const isl::schedule_node &original_node, const isl::schedule_node &tiled_node); + void IsolateLevelInfo(TileType &tile_type, isl::set &tiles, isl::set &all); + isl::schedule_node SetIsolateLoopType(isl::schedule_node node); + void TileTypeL1(isl::schedule_node &node, int *full_tile_min, int *full_tile_max, TileType &tile_type, bool &isolate, + isl::multi_val &sizes); + isl::schedule_node TileUbL1(isl::schedule_node node); + isl::schedule_node TileL0(isl::schedule_node node); + void PaddingIsolate(int &h_head, int &h_tail, int &w_head, int &w_tail); + void ComputeHInfo(int &h_base, bool &head, bool &tail, int &h_head, int &h_tail, int &win_h, int &win_cut_h); + void ComputeWInfo(int &w_base, bool &head, bool &tail, int &w_head, int &w_tail, int &win_w, int &win_cut_w); + bool NeedIsolate(); + bool BoolNot(bool b) { return !b; } + + private: + PassInfo &pass_info_; + ScopInfo &scop_info_; + Tiles tiles_; + TileSizes tile_sizes_; + std::vector> partition_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_TILING_H_ \ No newline at end of file diff --git a/src/poly/schedule_pass/transfer_stmt.cc b/src/poly/schedule_pass/transfer_stmt.cc new file mode 100644 index 0000000000000000000000000000000000000000..497622606e81b6a241fb0f19202daf02d7acb8c0 --- /dev/null +++ b/src/poly/schedule_pass/transfer_stmt.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "transfer_stmt.h" + +namespace akg { +namespace ir { +namespace poly { + + +isl::schedule TransferStmt::Run(isl::schedule curr_schedule) { + if (scop_info_.analysis_result_.GetTransferStmt().is_empty()) { + return curr_schedule; + } + pass_info_.transfer_stmt_ = scop_info_.analysis_result_.GetTransferStmt(); + isl::schedule_node root_ = curr_schedule.get_root(); + isl::schedule_node node = GetOuterBand(root_); + if (node.isa() || node.isa()) { + int n = static_cast(node.n_children()); + for (int i = 0; i < n; ++i) { + isl::schedule_node child = node.child(i); + CHECK(child.isa()) << "The child of set or sequence must filter!"; + isl::schedule_node_filter filter_node = child.as(); + isl::union_set filter = filter_node.get_filter(); + if (!filter.intersect(pass_info_.transfer_stmt_).is_empty()) { + filter = filter.subtract(pass_info_.transfer_stmt_); + child = isl::manage(isl_schedule_node_filter_set_filter(child.copy(), filter.copy())); + node = child.parent(); + return node.get_schedule(); + } + } + } + return curr_schedule; +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/transfer_stmt.h b/src/poly/schedule_pass/transfer_stmt.h new file mode 100644 index 0000000000000000000000000000000000000000..8fbc3cef3da95f10169c154f818a8bd59ee7e1a4 --- /dev/null +++ b/src/poly/schedule_pass/transfer_stmt.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_TRANSFER_STMT_H_ +#define POLY_TRANSFER_STMT_H_ + +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Transfer specified statements out of outer-most band's filter nodes if they are previously + * recorded as transfer_stmt. 
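+ * + * A small, hypothetical example: if the outer band is a sequence with filters { S_0; S_1 } and { S_2 }, and S_1 was + * recorded as a transfer statement, Run() subtracts S_1 from the first filter that intersects it (leaving { S_0 }) + * and immediately returns the updated schedule; a schedule with no recorded transfer statements is returned unchanged.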
+ */ +class TransferStmt : public SchedulePass { + public: + TransferStmt(ScopInfo &scop_info, PassInfo &pass_info) : scop_info_(scop_info), pass_info_(pass_info) { + pass_name_ = __FUNCTION__; + } + ~TransferStmt(){}; + + virtual isl::schedule Run(isl::schedule sch); + + private: + ScopInfo &scop_info_; + PassInfo &pass_info_; +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_TRANSFER_STMT_H_ diff --git a/src/poly/schedule_pass/try_mark_scalar_stmt.cc b/src/poly/schedule_pass/try_mark_scalar_stmt.cc new file mode 100644 index 0000000000000000000000000000000000000000..2f90a0fac1e604ed34a39828e2542f97a1c663b4 --- /dev/null +++ b/src/poly/schedule_pass/try_mark_scalar_stmt.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "try_mark_scalar_stmt.h" +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { +bool TryMarkScalarStmt::SubtreeHasPermutableBands(const isl::schedule_node &node) const { + bool all_non_permutable = false; + auto IsPermutable = [](const isl::schedule_node &node, bool check_coincident) -> bool { + if (!node) return false; + if (!node.isa()) return false; + if (!node.as().get_permutable()) return false; + if (node.as().n_member() < 1) return false; + return !(check_coincident && !node.as().member_get_coincident(0)); + }; + all_non_permutable = node.every_descendant([this, &IsPermutable](const isl::schedule_node &node) -> bool { + return !(IsPermutable(node, pass_info_.tile_check_coincident_)); + }); + return !all_non_permutable; +} + +isl::schedule_node TryMarkScalarStmt::InsertEmptyPermutableBand(isl::schedule_node node) { + isl::space space; + isl::multi_union_pw_aff mupa; + + space = node.get_schedule().get_domain().get_space(); + + space = space.set_from_params(); + mupa = isl::multi_union_pw_aff::zero(space); + node = node.insert_partial_schedule(mupa); + node = node.as().set_permutable(1); + + return node; +} + +isl::schedule TryMarkScalarStmt::Run(isl::schedule curr_schedule) { + const auto &curr_node = curr_schedule.get_root(); + // Return "root" if given an inappropriate node + if (!curr_node.isa() && !curr_node.isa()) return curr_schedule; + // Check whether each stmt is scalar + auto domain = curr_node.isa() ? 
curr_node.as().get_domain() + : curr_node.as().get_filter(); + if (!domain.every_set([](const isl::set &set) { + auto dim = set.n_dim(); + return dim == 0; + })) + return curr_schedule; + + // Return if there are any band nodes + if (SubtreeHasPermutableBands(curr_node)) return curr_schedule; + + auto node = GetOuterBand(curr_node); + // Mark to copy to UB + if (node.isa() || (IsSequenceOrSet(node))) { + node = InsertEmptyPermutableBand(node); + auto tag = REALIZE_UB; + node = node.insert_mark(isl::id(node.ctx(), tag)); + return node.get_schedule(); + } + + // Return if none of the above + return curr_schedule; +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass/try_mark_scalar_stmt.h b/src/poly/schedule_pass/try_mark_scalar_stmt.h new file mode 100644 index 0000000000000000000000000000000000000000..1feef51ee99d86d401034a809d33b8d4e0727f08 --- /dev/null +++ b/src/poly/schedule_pass/try_mark_scalar_stmt.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_TRY_MARK_SCALAR_STMT_H_ +#define POLY_TRY_MARK_SCALAR_STMT_H_ + +#include "poly/pass_info.h" +#include "poly/schedule_pass.h" + +namespace akg { +namespace ir { +namespace poly { + +/* + * Mark each scalar statement with a "realize_UB" mark node. "root" should be + * either a domain node or a filter node. + * + * First, check whether each statement in "root" is scalar. Each set of the + * union set represented by "root" represents a statement. We determine a scalar + * statement with the "HasNoDims" function, checking whether a given "set" has dims. + * + * Next, check whether the subtree of "root" has permutable bands, and return + * "root" if there are any permutable bands. + * + * Obtain the outermost permutable band, and this would go down to either a leaf + * node or a sequence/set node. + * + * If it comes to a leaf node, "root" represents a single scalar statement. Insert + * an empty band and mark this empty band with a "realize_UB" mark. + * + * If a sequence/set node is encountered, "root" represents multiple scalar + * statements; mark each child recursively with a "realize_UB" mark. + * + * Return the original "root" in other cases.
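+ * + * A hypothetical example: a schedule whose domain contains only zero-dimensional statements such as S_0[] and that + * has no permutable band gets an empty permutable band inserted at its outer node, marked with "realize_UB", so the + * scalar statements can still be promoted to UB; a schedule with any non-scalar statement or an existing permutable + * band is returned unchanged.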
+ */ +class TryMarkScalarStmt : public SchedulePass { + public: + TryMarkScalarStmt(PassInfo &pass_info) : pass_info_(pass_info) { pass_name_ = __FUNCTION__; }; + ~TryMarkScalarStmt(){}; + + virtual isl::schedule Run(isl::schedule sch); + + private: + PassInfo &pass_info_; + + bool SubtreeHasPermutableBands(const isl::schedule_node &node) const; + + isl::schedule_node InsertEmptyPermutableBand(isl::schedule_node node); +}; + +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_TRY_MARK_SCALAR_STMT_H_ diff --git a/src/poly/schedule_pass_mgr.cc b/src/poly/schedule_pass_mgr.cc new file mode 100644 index 0000000000000000000000000000000000000000..2713adaab451bcd8c1360f3378060fb2d778bd28 --- /dev/null +++ b/src/poly/schedule_pass_mgr.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "poly/schedule_pass_mgr.h" + +namespace akg { +namespace ir { +namespace poly { + +const std::vector> &SchedulePassMgr::GetSchedulePasses() const { return schedule_passes_; } + +void SchedulePassMgr::RegisterPass(std::shared_ptrpass) { + CHECK(pass); + schedule_passes_.push_back(pass); +} + +isl::schedule SchedulePassMgr::Run(const isl::schedule &sch) { + CHECK(sch); + return Run(sch, schedule_passes_); +} + +isl::schedule SchedulePassMgr::Run(const isl::schedule &sch, const std::vector> &passes) { + CHECK(sch); + + std::chrono::high_resolution_clock::time_point timer_start; + scop_info_.ClearTimeRecords(); + + auto final_sch = sch; + need_restart_ = false; + + for (auto &pass : passes) { + std::stringstream time_log; + TIMER_START; + final_sch = pass->Run(final_sch); + time_log << "[ Polyhedral exec time" << (scop_info_.cube_info_.IsSpecGemm() ? "_specgemm" : "") << " ], " + << pass->GetPassName() << " spent " << TIMER_DURATION << " ms"; + + LOG(INFO) << time_log.str(); + scop_info_.RecordTime(time_log.str()); + + scop_info_.DumpSchTree(pass->GetPassName(), final_sch); + + if (pass->restart_) { + need_restart_ = true; + break; + } + } + return final_sch; +} + +isl::schedule SchedulePassMgr::Run(const isl::schedule &sch, PassMgrStrategy &strategy) { + CHECK(sch); + strategy.RegisterPasses(); + std::vector> passes = strategy.GetPasses(); + return Run(sch, passes); +} + +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/schedule_pass_mgr.h b/src/poly/schedule_pass_mgr.h new file mode 100644 index 0000000000000000000000000000000000000000..c46f7ca6f982a219537b530877f5e49308f08827 --- /dev/null +++ b/src/poly/schedule_pass_mgr.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef POLY_PASS_MGR_H_ +#define POLY_PASS_MGR_H_ + +#include "poly/schedule_pass.h" +#include "poly/pass_mgr_strategy.h" + +namespace akg { +namespace ir { +namespace poly { + +class SchedulePassMgr { + public: + SchedulePassMgr(ScopInfo &scop_info) : scop_info_(scop_info){} + const std::vector> &GetSchedulePasses() const; + void RegisterPass(std::shared_ptrpass); + isl::schedule Run(const isl::schedule &sch); + isl::schedule Run(const isl::schedule &sch, const std::vector> &passes); + isl::schedule Run(const isl::schedule &sch, PassMgrStrategy &strategy); + ~SchedulePassMgr() {} + + bool need_restart_{false}; + ScopInfo &scop_info_; + private: + std::vector> schedule_passes_; +}; +} // namespace poly +} // namespace ir +} // namespace akg +#endif // POLY_PASS_MGR_H_ diff --git a/src/poly/scop.cc b/src/poly/scop.cc index 1fee7fc5f9d02552c9a343265b68392c1dde0903..38fc3c3aa452e251e624dedf70c25fac2b2cca4e 100644 --- a/src/poly/scop.cc +++ b/src/poly/scop.cc @@ -15,48 +15,47 @@ */ #include "poly/scop.h" +#include + #include "poly/scop_builder.h" -#include "poly/transform.h" +#include "poly/poly_util.h" #include "poly/cce_isl_emitter.h" +#include "poly/davinci_mgr_strategy.h" +#include "poly/schedule_pass_mgr.h" namespace akg { namespace ir { namespace poly { -Scop::Scop(Stmt body, const Binds &binds, isl::ctx ctx, bool is_spec_gemm) - : body_(std::move(body)), - binds_(binds), - binds_orig_(binds), - ctx_(ctx), - is_spec_gemm_(is_spec_gemm), - isolated_(false), - isolated_idx_(0) { - if (is_spec_gemm) { - iter_prefix_ = kGemmIterNamePrefix; - } else { - iter_prefix_ = kIterNamePrefix; - } -} - -Scop::~Scop() { - if (model_ != nullptr) { - delete model_; - model_ = nullptr; +void Scop::ParseUserConfig(const Map &attrs, const Map &extern_buffer, + bool is_spec_gemm, bool is_tuning, bool is_dynamic) { + info_.user_config_.SetAttrs(attrs); + info_.user_config_.SetBind(extern_buffer); + info_.user_config_.SetOriginBind(extern_buffer); + info_.user_config_.SetIsTuning(is_tuning); + info_.user_config_.SetDynamic(is_dynamic); + + info_.cube_info_.SetAttrs(attrs); + info_.cube_info_.SetSpecGemm(is_spec_gemm); + if (info_.cube_info_.IsSpecGemm()) { + info_.cube_info_.SetConvAttrInfo(attrs); } } -isl::set Scop::CreateParamsSet() const { - auto space = CreateParamsSpace(ctx_, params_); +isl::set CreateParamsSet(ScopInfo &info) { + auto space = CreateParamsSpace(info.GetCtx(), info.user_config_.GetParams()); auto context = isl::set::universe(space); - - for (const auto ¶m : params_) { - isl::aff aff(isl::aff::param_on_domain(space, isl::id(ctx_, param.second->name_hint))); + auto dynamic_shape = info.user_config_.GetDynamicShape(); + auto params = info.user_config_.GetParams(); + for (const auto ¶m : params) { + isl::aff aff(isl::aff::param_on_domain(space, isl::id(info.GetCtx(), param.second->name_hint))); context = context & (aff > 0); - if (!dynamic_shape_.empty()) { - for (const auto &ds : dynamic_shape_) { - if (auto dsn = ds.as()) { - if (dsn->tensor_name == param.second->name_hint) { - context = context & (aff < dsn->poly_upper_bound); - } + if (dynamic_shape.empty()) { + 
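+ // no dynamic shape information was provided, so only the positivity constraint applies to this parameter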
continue; + } + for (const auto &ds : dynamic_shape) { + if (auto dsn = ds.as()) { + if (dsn->tensor_name == param.second->name_hint) { + context = context & (aff < dsn->poly_upper_bound); } } } @@ -65,15 +64,18 @@ isl::set Scop::CreateParamsSet() const { } isl::schedule Scop::GenIsl() { - body_ = PeelOuterLetStmt(body_, outer_let_stmts_); - - GetParams(); - if (!params_.empty()) { - auto mutator = ConsolidateExprMutator(params_); + auto outer_let_stmts = info_.user_config_.GetOuterLetStmts(); + body_ = PeelOuterLetStmt(body_, outer_let_stmts); + info_.user_config_.SetOuterLetStmts(outer_let_stmts); + info_.user_config_.CollectParams(); + auto params = info_.user_config_.GetParams(); + if (!params.empty()) { + auto mutator = ConsolidateExprMutator(params); body_ = mutator.Mutate(body_); Binds new_binds; - for (auto &it : binds_) { + auto binds = info_.user_config_.GetBind(); + for (auto &it : binds) { Array shape = it.first->shape; for (size_t i = 0; i < shape.size(); ++i) { if (!is_const(shape[i])) { @@ -94,170 +96,58 @@ isl::schedule Scop::GenIsl() { new_binds.Set(t, b); } - binds_ = new_binds; + info_.user_config_.SetBind(new_binds); } - isl::space param_space = CreateParamsSpace(ctx_, params_); - isl::set param_set = CreateParamsSet(); + isl::space param_space = CreateParamsSpace(ctx_, params); + isl::set param_set = CreateParamsSet(info_); - // Make schedule + info_.user_config_.SetBody(body_); Stmt stmt = body_; - isl::schedule schedule_tmp = MakeScheduleTree(param_space, param_set, stmt, *this); + // Make schedule + isl::schedule schedule_tmp = MakeScheduleTree(param_space, param_set, stmt, info_); + + info_.CreateDataFlowInfo(); + info_.cube_info_.UpdateComputeAttrInfo(); + info_.cube_info_.ComputeByPassL1(); return schedule_tmp; } -isl::schedule Scop::Transform(isl::schedule sched, bool coincident, bool tuning) { - auto timer_start = std::chrono::high_resolution_clock::now(); - CreateDataFlowInfo(); - DumpSchTree("00_before_group" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sched); - bool has_group = false; - isl::schedule sch = sched; - if (!disable_group_) { - sch = GroupStatements(sched, has_group); - DumpSchTree("01_after_group" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sch); - } - - // perform polyhedral transformation ongoing, gradually - poly::Transform transform(sch, data_, *this, has_group); - - TIMER_START; - data_.copyin = transform.ComputeCopyIn(); - TIMER_SHOW("computeCopyIn", std::string(is_spec_gemm_ ? "_specgemm" : "")); - - CheckAndRemoveUninitializedCopyin(data_.copyin, binds_orig_); - sch = transform.Initialize(coincident); - - if (outer_band_need_split_ && !is_spec_gemm_) { - sch = SplitOuterBand(sch); - DumpSchTree("06_splitOuterBand" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sch); - } - - TIMER_START; - data_.inter_band_dependency = transform.ComputeFakeCopyin(sch).subtract(data_.copyin); - TIMER_SHOW("computeFakeCopyin", std::string(is_spec_gemm_ ? "_specgemm" : "")); - - if (!is_spec_gemm_ && (IsConv() || IsGemm())) { - this->sch_ = sch.get_map(); - isl::union_map fake_copyin = transform.ComputeFakeCopyin(sch); - ComputeTransferCopyin(fake_copyin); - } - - isl::schedule tiling_sch = sch; - if (!is_spec_gemm_ && !data_.transfer_stmt.is_empty()) { - TransferStmt(tiling_sch); - } - - // get compute attr for conv and load3d op - UpdateComputeAttrInfo(); - if (PRINT_SCHEDULE_INFO) LOG(INFO) << GenHalide(sch); - - // 4. 
tiling, an initial strategy, pending optimization - Tiles tiles; - if (tuning) { - spaces_ = GenerateTilingSpace(this, sch, dump_tuning_level_, custom_tiling_, dynamic_shape_); - return sch; - } - - TIMER_START; - InitDimensionInfo(tiling_sch); - MergeTilingInfo(tiles); - TIMER_SHOW("AutoTiling", std::string(is_spec_gemm_ ? "_specgemm" : "")); - - if (IsConv()) CreateConvModel(is_dynamic_); - - TIMER_START; - isl::schedule tmp_schedule = transform.TileOuterBand(tiles, sch); - is_tiled_ = true; - TIMER_SHOW("tileOuterBand", std::string(is_spec_gemm_ ? "_specgemm" : "")); - - // for scalar stmt, keep going when coincident = false - if (tmp_schedule.get_map().is_equal(sch.get_map()) && coincident) { - LOG(WARNING) << "same schedule"; - return sched; - } - - if (sch.plain_is_equal(tmp_schedule)) { - tmp_schedule = transform.TryMarkScalarStmts(sch.get_root()).get_schedule(); - } - - sched = tmp_schedule; - DumpSchTree("07_tileOuterBand" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sched); +isl::schedule Scop::Transform(const isl::schedule &input_schedule) { + info_.user_config_.SetConsiderCoincidence(true); + DavinciMgrStrategy davinci_strategy(info_); + SchedulePassMgr mgr(info_); + auto final_schedule = mgr.Run(input_schedule, davinci_strategy); + info_.DumpTransform("davinci_transfrom.log", davinci_strategy.pass_info_); - if (transform.HasInvariantDependence()) { - sched = transform.ReorderInvariantSetSchedule(sched); - DumpSchTree("07_01_reorderAfterTileOuterBand" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sched); + // We offer a restart mechanism for scalar stmt that cannot tile: do not consider coincidence + // and re-compute/re-tile to generate final schedule. + if (mgr.need_restart_) { + info_.user_config_.SetConsiderCoincidence(false); + DavinciMgrStrategy scalar_strategy(info_); + final_schedule = mgr.Run(input_schedule, scalar_strategy); + info_.DumpTransform("scalar_transform.log", scalar_strategy.pass_info_); } - sched = ResetCoincidenceOfReduceAxis(sched, data_.reduce_stmts); - if (pragma_set_all_coincident_) { - sched = transform.SetAllCoincident(sched); - } - // 5. apply intra tile rescheduling - if (!is_dynamic_ || !IsConv()) { - TIMER_START; - transform.IntraTileReschedule(sched, tile_inner_band_, is_spec_gemm_); - TIMER_SHOW("IntraTileRescheduling", std::string(is_spec_gemm_ ? "_specgemm" : "")); - DumpSchTree("08_0_reschedule" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sched); - } - - sched = ReorderInnerBandLoops(sched); - DumpSchTree("08_1_reorderInnerBandLoops" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sched); - sched = ChangeMarkNodePosition(sched); - DumpSchTree("08_2_changeMarkNodePos" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sched); - sched = LabelRealizeOutPosition(sched); - DumpSchTree("08_3_labelAlloc" + std::string(is_spec_gemm_ ? "_specgemm" : ""), sched); - - sched = InsertNodeForAllocC(sched); - - std::vector> partition_info = AddTileInfo(transform.getPartitionInfo()); - AddPartitionInfoToData(partition_info); - - ComputeByPassL1(); - - this->schedule_ = sched; - - TIMER_START; - AddStateTensorsDataFlow(); - ReorderBufferedDefInfos(); - RecordAllTensorBufferFootprintToExtension(); - if (enable_hoist_cond_write_) { - FindConditionalWritePromotions(); - } - TIMER_SHOW("MemoryPromotion", std::string(is_spec_gemm_ ? "_specgemm" : "")); - - if (!is_spec_gemm_ && !data_.transfer_stmt.is_empty()) { - TransferStmt(this->schedule_); - } - DumpSchTree("09_mem_promote" + std::string(is_spec_gemm_ ? 
"_specgemm" : ""), this->schedule_); - - this->schedule_ = ReorderMarkNodes(this->schedule_); - DumpSchTree("10_reorderMarkNodes" + std::string(is_spec_gemm_ ? "_specgemm" : ""), schedule_); - this->schedule_ = MarkFuseOp(this->schedule_); - DumpSchTree("11_markFuseOp" + std::string(is_spec_gemm_ ? "_specgemm" : ""), this->schedule_); - - // if coincidence constraints are disabled (due to reschedule), we cannot determine multicore axis reliably - bool can_use_multiCore = !is_spec_gemm_ && coincident; - if (can_use_multiCore || enable_mark_multi_core_) { - this->schedule_ = MarkOuterMost(this->schedule_); - DumpSchTree("12_markOuterMost" + std::string(is_spec_gemm_ ? "_specgemm" : ""), this->schedule_); - } - return this->schedule_; + if (final_schedule.get()) info_.analysis_result_.SetTranstormedSchedule(final_schedule); + return final_schedule; } -isl::id_list Scop::CreateIteratorList(const isl::schedule &schedule_iter, const std::string &prefix) { +isl::id_list CreateIteratorList(const isl::schedule &schedule_iter, const std::string &prefix) { + int depth = 0; auto root = schedule_iter.root(); - auto fn = [this](const isl::schedule_node &node) -> isl::schedule_node { + auto fn = [&depth](const isl::schedule_node &node) -> isl::schedule_node { if (node.as()) { - auto depth = static_cast(node.schedule_depth()); - depth = depth + static_cast(node.as().n_member()); - this->depth_ = depth > this->depth_ ? depth : this->depth_; + auto schedule_depth = static_cast(node.schedule_depth()); + schedule_depth = schedule_depth + static_cast(node.as().n_member()); + depth = schedule_depth > depth ? schedule_depth : depth; } return node; }; root = root.map_descendant_bottom_up(fn); - isl::id_list res(root.ctx(), depth_); + isl::id_list res(root.ctx(), depth); - for (int i = 0; i < depth_; ++i) { + for (int i = 0; i < depth; ++i) { std::stringstream ss; ss << prefix << i; res = res.add(isl::id(root.ctx(), ss.str())); @@ -270,9 +160,12 @@ size_t &AstNodeNum() { return n; } constexpr auto AST_NODE_ID_PREFIX = "__node_"; -Stmt Scop::GenHalide(const isl::schedule &schedule_gen) { - // we should check the return value to be isl_stat_ok, but it returns isl_stat_error, so we skip this check. - static_cast(isl_options_set_ast_build_group_coscheduled(schedule_.ctx().get(), isl_bool_true)); +Stmt GenHalide(ScopInfo &info, const isl::schedule &sch, bool used_for_tile_out_band) { + if (!used_for_tile_out_band) { + // we should check the return value to be isl_stat_ok, but it returns isl_stat_error, so we skip this check. + static_cast(isl_options_set_ast_build_group_coscheduled(sch.ctx().get(), isl_bool_true)); + if (info.cube_info_.IsConv()) info.cube_info_.CreateConvModel(); + } NodeInfoRepo node_info_repo; auto gather = [&node_info_repo](const isl::ast_node &node, const isl::ast_build &build) -> isl::ast_node { @@ -294,31 +187,52 @@ Stmt Scop::GenHalide(const isl::schedule &schedule_gen) { }; // set up ast builder - auto builder = isl::ast_build(schedule_gen.ctx()); + auto builder = isl::ast_build(sch.ctx()); builder = builder.set_at_each_domain(gather); - isl::id_list iters = CreateIteratorList(schedule_gen, iter_prefix_); + auto iter_prefix = info.user_config_.GetIterPrefix(info.cube_info_.IsSpecGemm()); + isl::id_list iters = CreateIteratorList(sch, iter_prefix); builder = builder.set_iterators(iters); // build processing std::chrono::high_resolution_clock::time_point timer_start; TIMER_START; - auto ast_node = builder.node_from(schedule_gen); - TIMER_SHOW("NodeFrom", std::string(is_spec_gemm_ ? 
"_specgemm" : "")); + auto ast_node = builder.node_from(sch); + TIMER_SHOW("NodeFrom", std::string(info.cube_info_.IsSpecGemm() ? "_specgemm" : "")); ast_node = CanonicalizeBlockInAst(ast_node); TIMER_START; - Stmt stmt = CCEIslEmitter(*this, node_info_repo, iters).Emit(ast_node); - TIMER_SHOW("CCEIslEmitter", std::string(is_spec_gemm_ ? "_specgemm" : "")); + Stmt stmt; + if (PRINT_ISL_EMMITER) { + if (used_for_tile_out_band) { + PrintHeader("CCEIslEmitter"); + stmt = CCEIslEmitter(info, node_info_repo, iters).Emit(ast_node); + } else { + PrintHeader("IslEmitter"); + stmt = IslEmitter(info, node_info_repo, iters).Emit(ast_node); + } + } else { + stmt = CCEIslEmitter(info, node_info_repo, iters).Emit(ast_node); + } + + TIMER_SHOW("CCEIslEmitter", std::string(info.cube_info_.IsSpecGemm() ? "_specgemm" : "")); - if (is_dynamic_) { - stmt = RestoreCombinedParams(stmt); + if (PRINT_EMMITER) { + PrintHeader("FINAL SCHEDULE"); + std::cout << PrettyPrintSchTree(sch) << std::endl; + PrintHeader("FINAL ASTNODE"); + std::cout << FormatMupaStr(ast_node.to_str(), false) << std::endl << std::endl; + PrintHeader("FINAL ASTNODE TO C"); + std::cout << ast_node.to_C_str() << std::endl; + PrintHeader("FINAL STMT"); + std::cout << stmt; } return stmt; } -Stmt OptimizeHalide(const Stmt &s, bool dynamic_shape) { return OptimizeCce(s, dynamic_shape); } +Stmt Scop::GenHalide(const isl::schedule &sch) { return poly::GenHalide(info_, sch, false); } + } // namespace poly } // namespace ir } // namespace akg diff --git a/src/poly/scop.h b/src/poly/scop.h index 8688122827e096b3bfb4d3ae862d267c4d67f63a..51867d8175a99a6c12a7669492a03ce8a7295ee8 100644 --- a/src/poly/scop.h +++ b/src/poly/scop.h @@ -16,509 +16,35 @@ #ifndef POLY_SCOP_H_ #define POLY_SCOP_H_ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "poly/isl.h" -#include "poly/stmt_parse.h" -#include "poly/poly_util.h" -#include "poly/dma_dataflow.h" -#include "poly/custom_tiling.h" -#include "poly/dynamic_shape.h" -#include "pass/convolution_model.h" - -// timer records -#define TIMER_START timer_start = std::chrono::high_resolution_clock::now() -#define TIMER_DURATION \ - (std::chrono::duration_cast>(std::chrono::high_resolution_clock::now() - timer_start) \ - .count()) * \ - 1000 -#define TIMER_SHOW(NAME, SPEC_GEMM) \ - { LOG(INFO) << "[ Polyhedral exec time" << SPEC_GEMM << " ], " << NAME << " spent " << TIMER_DURATION << " ms"; } - -// Prime numbers for prime-param replacement -#define PRIME_1 53 -#define PRIME_2 59 -#define PRIME_3 61 +#include "poly/scop_info.h" +#include "poly/pass_info.h" namespace akg { namespace ir { namespace poly { -class TensorFootprintCluster; - -struct OperatorDomainSpace { - isl::space param_space; - isl::multi_id tuple; -}; - -using IteratorMap = std::unordered_map, isl::IslIdIslHash>; -using StatementMap = std::unordered_map; -using AccessMap = std::unordered_map; -using ReduceMap = std::unordered_map>; -using BufferBindVec = std::vector>; -using OperatorDomainMap = std::unordered_map; -using PartialTileAccessespair = std::vector; -using ReduceStmtMap = std::unordered_map, isl::IslIdIslHash>; -using CondVarsMap = std::unordered_map, isl::IslIdIslHash>; - -struct NodeInfo { - isl::pw_multi_aff iterator_map; - isl::ast_build build; -}; -using NodeInfoRepo = std::unordered_map; - -void GetAffOffsetAndNumVars(const isl::aff &aff, int &offset, int &num_vars); -bool IsAffVarPlusOffset(const isl::aff &aff); -bool IsAffNonZeroConst(const isl::aff &aff); - -std::string TensorMarkTag(MemType 
memType, MemFlow memFlow); - -Stmt OptimizeHalide(const Stmt &s, bool dynamic_shape = false); - class Scop { public: - struct TilingInfo; - using Binds = Map; - using Tiles = std::vector; - struct ParamInfo { - std::string type_key; - Expr key; - Expr value; - }; - enum AtomicType { Equ = 0, Add }; - - // transform, save group stmts - std::unordered_map group_filter_map_; - // save halide IR let stmts - std::vector outer_let_stmts_; - // save halide IR realize - std::unordered_set realize_from_input_; - Stmt body_; - Binds binds_; - const Binds binds_orig_; - - // dynamic shape - std::unordered_map params_; - std::unordered_map params_rev_map_; - - isl::ctx ctx_; - bool is_spec_gemm_{false}; - bool is_tiled_{false}; - int conv_back_prop_filter_{0}; - int bypassL1_{0}; - int dump_tuning_level_{0}; - bool disable_group_{false}; - bool tile_inner_band_{false}; - bool pragma_set_all_coincident_{false}; - bool remove_self_dependence_{true}; - bool force_remove_self_dependence_{false}; - bool remove_invariant_dependence_{false}; - bool compute_reschedule_{false}; - bool disable_schedule_shift_{false}; - bool enable_schedule_max_constant_{false}; - bool disable_loop_reversal_{false}; - bool disable_loop_fusion_{false}; - bool mod_schedule_shift_{false}; - bool conv_special_dma_{false}; - bool tile_check_coincident_{true}; - bool reorder_schedule_{false}; - bool sink_last_axis_{true}; - bool keep_outer_band_order_{false}; - bool optimize_for_davinci_{false}; - bool enable_feature_library_{false}; - bool enable_hoist_cond_write_{true}; - bool enable_mark_multi_core_{false}; - bool is_dynamic_{false}; - int dump_pass_ir_{0}; - int depth_ = 0; - int dynamic_shape_bound_{0}; - int tile_size_is_var_{0}; - int outer_band_need_split_{0}; - int pragma_is_conv_{0}; - - std::string dump_poly_dir_; - std::string kernel_name_; - std::string iter_prefix_; - isl::schedule schedule_; - isl::union_map sch_; // before tiling, after ungroup. - - std::vector old_l1_write_; - - NodeRef spaces_; - int matB_dim_h_{-1}; - int matB_dim_w_{-1}; - - /// Store related information for analysis - struct Data { - isl::union_map reads; - isl::union_map copyin; - isl::union_map writes; - isl::union_map fake_copyin; - isl::union_set transfer_stmt; - isl::union_map inter_band_dependency; - ReduceStmtMap reduce_stmts; - AccessMap accesses; - StatementMap statements; - StmtOpInfoMap stmt_op_Info; - IteratorMap iterators; - OperatorDomainMap domains; - ReduceMap reduces; - BufferBindVec vecs; - std::vector update_tensors; - std::vector attrs; - - std::vector> range_info; - std::vector range_stride; - } data_; - - std::shared_ptr gemm_a_transpose_fp_cluster_; - std::shared_ptr gemm_b_transpose_fp_cluster_; - std::shared_ptr im2col_fp_cluster; - - // dimension info read from file,erery dimInfo - // represents every row in the file. 
- struct DimensionInfo { - int64_t index; - std::string axis; - int64_t l1_tiling_size; - int64_t l0_tiling_size; - int64_t dim_seq; - Expr l1_var; - Expr l0_var; - Expr pragma; - bool is_inner{false}; - }; - using TileSizes = std::vector; - - std::vector dim_infos_; - - std::map param_tiling_map_; - - std::map fractal_int_info_; - std::map fractal_str_info_; - - struct TilingInfo { - int tiling_flag; // flag=1, tailing; flag=0, not tailing - std::vector dim_infos; - }; + Scop(Stmt body, isl::ctx ctx) : info_(ScopInfo(ctx)), body_(std::move(body)), ctx_(ctx) {} + ~Scop() = default; - std::vector conv_mnk_dims_; - struct BufferedDecl { - enum Kind { L1, L0, L0A, L0B, L0C, UB }; - - isl::id tensor_id; - std::vector sizes; - Type type; - Kind kind; - Tensor tensor; - }; - - std::map> tensor_name_flows_; - std::map tensor_mem_flows_; - std::vector buffer_def_infos_; - std::queue buffer_footprint_queue_; - BufferDefInfo place_holder_; - std::vector> stmt_type_; - - struct BufferedFootPrintInfo { - std::shared_ptr cluster; - isl::union_map outer_schedule; - isl::id cluster_id; - }; - - std::deque tiling_constraints_; - std::string b_dim_; - - std::unordered_map n_clusters_; - - std::unordered_map buffered_decls_; - - std::vector> active_buffer_footprints_; - - std::unordered_set conditional_write_buffer_footprints_; - - Map attr_info_; - - bool isolated_{false}; - int isolated_idx_{0}; - int out_reduce_init_{0}; - std::vector custom_tiling_; - std::vector dynamic_shape_; - bool dynamic_shape_conv_full_parametric_{false}; - bool pragma_analyze_reuse_buffer_{false}; - bool pragma_speedup_tiling_{false}; - bool pragma_allow_tail_tiling_{true}; - bool pragma_analyze_multicore_{true}; - - ConvolutionModel *model_{nullptr}; - - struct GemmVar { - VarExpr var_batch_name{"b"}; - VarExpr var_no_name{"no"}; - VarExpr var_mo_name{"mo"}; - VarExpr var_mi_name{"mi"}; - VarExpr var_ni_name{"ni"}; - VarExpr var_ko_name{"ko"}; - VarExpr var_ki_name{"ki"}; - }; - - // scop - Scop(Stmt body, const Binds &binds, isl::ctx ctx, bool is_spec_gemm); - ~Scop(); - static std::shared_ptr make(const Stmt &body, const Binds &binds, isl::ctx ctx, bool is_spec_gemm) { - return std::make_shared(body, binds, ctx, is_spec_gemm); - } - - // main + void ParseUserConfig(const Map &attrs, const Map &extern_buffer, + bool is_spec_gemm, bool is_tuning, bool is_dynamic); isl::schedule GenIsl(); - void ComputeTransferCopyin(isl::union_map &fake_copyin); - void TransferStmt(isl::schedule &t_sch); - void ComputeByPassL1(); - void AddPartitionInfoToData(const std::vector> &partition_info); - isl::schedule Transform(isl::schedule, bool coincident = true, bool tuning = false); - Stmt GenHalide(const isl::schedule &); + isl::schedule Transform(const isl::schedule &input_schedule); + Stmt GenHalide(const isl::schedule &sch); - // transform - isl::schedule GroupStatements(const isl::schedule &sch, bool &has_group); - void InitDimensionInfo(const isl::schedule &); - void MergeTilingInfo(Tiles &tiling_infos); - std::unordered_map GetConvInfoForTiling(); - std::vector> AddTileInfo(const std::vector> &partition_info); - isl::schedule ChangeMarkNodePosition(const isl::schedule &); - isl::schedule LabelRealizeOutPosition(const isl::schedule &) const; - isl::schedule ReorderMarkNodes(const isl::schedule &) const; - isl::schedule ReorderInnerBandLoops(const isl::schedule &schedule) const; - bool InjectMulticoreToSchedule(isl::schedule_node &outer_band); - bool SingleMulticoreBand(isl::schedule_node &outer_band); - isl::schedule MarkOuterMost(const 
isl::schedule &); - isl::schedule MarkFuseOp(const isl::schedule &) const; - isl::schedule InsertNodeForAllocC(isl::schedule &sched); + ScopInfo info_; - // tool - isl::id_list CreateIteratorList(const isl::schedule &schedule_iter, const std::string &prefix); - int ExtractIntFromAttrs(const std::string &name) const; - Expr ExtractExprFromAttrs(const std::string &name) const; - std::string ExtractStringFromAttrs(const std::string &name) const; - std::unordered_set ExtractWithStmtId() const; - std::string ExtractStringFromAttrsAndInfo(const std::string &name) const; - isl::pw_multi_aff RemoveConstOffsetFromBufferFootprint(const isl::pw_multi_aff &promotion); - CondVarsMap ExtractCondVarsMap() const; - - // data info - static bool IsRead(const isl::id &id) { return IsEndsWith(id.get_name(), kReadSuffix); } - static bool IsWrite(const isl::id &id) { return IsEndsWith(id.get_name(), kWriteSuffix); } - static bool IsGMWrite(const isl::id &id) { return id.get_name() == std::string("GMwrite"); } - const isl::union_set Domain() const; - AtomicType GetAtomicWrite(const isl::id &id) const; - Type GetDtypeOf(const std::string &tensor_name) const; - Type GetDtypeOf(const isl::id &var) const { return GetDtypeOf(var.get_name()); } - Type GetDtypeOf(const isl::ast_expr &e) const; - bool IsInBinds(const std::string &name) const; - inline bool IsInBinds(const isl::id &id) const { return IsInBinds(id.get_name()); } - void RecordReduceStmt(const isl::id &stmt_id, const std::vector &reduce_axis_list); - bool InitRangeStrideVec(); - bool MayWriteAfterRead(const std::string &name) const; - bool IsElewiseVMStmt(const isl::id &id) const; - void CreateDataFlowInfo(); - StmtIdHashMap StmtWriteMap(); - StmtIdHashMap StmtReadMap(); - StmtIdHashMap StmtCopyinMap(); - bool IsCopyinTensor(const std::string &tensorName); - void AddTensorDataFlow(const std::vector &mem_flow, const std::vector &name_flow); - void AddStateTensorsDataFlow(); - Tensor FindTensor(const isl::id &var); - Tensor FindTensor(const std::string &str); - Tensor FindTensorInOrig(const isl::id &var); - Tensor FindTensorInOrig(const std::string &str); - Tensor FindTensorWithLargestShape(const isl::id &var); - Tensor FindTensorWithLargestShape(const std::string &str); - void ParseIntAttr(const Map &attrs, const std::string &attr_name, int *attr_to_set); - void ParseBoolAttr(const Map &attrs, const std::string &attr_name, bool *attr_to_set); - void ParseStringAttr(const Map &attrs, const std::string &attr_name, std::string *attr_to_set); - void ParseCustomTilingAttr(const Map &attrs, const std::string &attr_name, - std::vector *attr_to_set); - void ParseDynamicShapeAttr(const Map &attrs, const std::string &attr_name, - std::vector *attr_to_set); - void SetAttrs(const Map &attrs); - std::string GetbDim() const { return b_dim_; } - std::string GetcDim(); - isl::id GetOriginTensorId(const std::string &name) const; - isl::id GetOriginTensorId(const isl::id &id) const; - - // conv info - std::vector GetIsolateVec(int range_idx); - std::vector GetRange(int range_idx); - std::string ConvOutName(); - air::DataType MadCastType(); - bool IsConvHeadTail(const std::string &conv_output, const isl::id &stmtId, const StmtOpInfo &op_info, - const StmtIdHashMap &op_write_map); - bool IsA(const std::string &name) const; - bool IsB(const std::string &name) const; - bool IsC(const std::string &name) const; - bool IsCUB(const std::string &name) const; - std::string GetAName() const; - std::string GetBName() const; - std::string GetCName() const; - bool IsIm2col() const; - bool 
IsLoad3dL1Ub() const; - bool IsLoad3dL1UBStmt(const std::string &stmtName) const; - bool HasCube() const; - bool IsConv() const; - bool IsConvBackpropInput() const; - bool IsConvBackpropFilter() const; - bool IsGemm() const; - bool IsGemmDataTranspose() const; - bool IsGemmDataTransposeBlock() const; - bool IsGemmDataTransposeInnerBlock() const; - bool IsGemmWeightTranspose() const; - bool IsGemmWeightTransposeBlock() const; - bool IsGemmWeightTransposeInnerBlock() const; - void FindComputeAttr(const std::vector &op_keys); - void UpdateComputeAttrInfo(); - void CreateConvModel(bool is_dynamic); - bool IsFilterCanByPass(); - void GetConvMNKInfo(std::vector &dim_infos); - - // record buffer footprint - bool UpdateBufferDefInfoSizes(const isl::id &tensor_id, const std::vector &new_sizes); - void AddOneBufferDefInfo(const isl::id &ancestorId, const std::vector> &data_stream); - void MakeBufferFootprintCluster(BufferDefInfo &tensor_info); - void GatherBufferFootprintDefInfo(const isl::schedule_node &tree, BufferDefInfo &tensor_info); - void GatherFractalDefInfo(const isl::schedule_node &tree, BufferDefInfo &tensor_info, std::vector &sizes); - void HoistIm2colBufferFootprintCluster(const isl::union_map &schedule, const isl::schedule_node &node, int index, - BufferDefInfo &tensor_info); - void MakeMultiBufferFootprint(const isl::union_map &schedule, const isl::schedule_node &node, int &index, - BufferDefInfo &tensor_info); - void ReorderBufferedDefInfos(); - void RecordAllTensorBufferFootprintToExtension(); - void CollectBufferFootprintDefInfo(BufferDefInfo &tensor_info, const isl::union_map &schedule, - const isl::schedule_node &node); - isl::schedule HoistBufferFootprintAtMarkNode(const isl::schedule_node &root, const std::string &markTag, - size_t index); - isl::schedule_node HoistBufferFootprintAtMarkNode(const isl::schedule_node &tree, size_t index); - isl::schedule_node HoistTensorClusterFootprint(isl::schedule_node tree, size_t index, const isl::union_map &schedule); - std::shared_ptr GetFootPrintsCluster(const isl::id &tensor_id); - const std::vector &BufferDefInfos() const { return buffer_def_infos_; } - bool HasBufferDefInfo(const isl::id &tensor_id) const; - const BufferDefInfo &GetBufferDefInfo(const isl::id &tensor_id) const; - - void UpdateFractalIntFirstInfoConvForward(std::vector im2col_fp_cluster_size, - std::vector fractal_fp_cluster_size); - void UpdateFractalIntFirstInfoConvBackpropFilter(std::vector im2col_fp_cluster_size, - std::vector fractal_fp_cluster_size); - void UpdateFractalIntFirstInfo(bool is_conv_backprop_filter, const std::vector &im2col_fp_cluster_size, - const std::vector &fractal_fp_cluster_size); - void UpdateFractalIntLastInfo(std::vector filter_fp_cluster_size); - void SetFindBuffer(const isl::id &tensor_id, bool find_buffer); - int CountBufferDefInfo(const isl::id &tensor_id) const; - void AddGemmTransposeFpCluster(const isl::union_map &schedule); - const std::vector> &ActiveBufferFootprints() const { - return active_buffer_footprints_; - } - std::vector> CollectBufferedFootprints( - const isl::union_set &active_points, const isl::id &tensor_id) const; - std::vector CollectBufferedFootprintsIndexes(const isl::union_set &active_points, - const isl::id &tensor_id) const; - bool IsWriteWholeBufferFootPrint(const isl::id &poly_ref_id) const; - bool IsConditionalWriteTensor(const std::string &name, - const std::vector> &write_stmts) const; - void FindConditionalWritePromotions(); - - // specgemm - void UpdateSpecGemmFractalInfo(const BufferDefInfo &tensor_info); 
-  Binds BuildConvGemmBand();
-  void BuildConvGemmFeatureBand(Scop::Binds &new_bind);
-  void BuildConvGemmFilterBand(Scop::Binds &new_bind);
-  void BuildConvGemmResultBand(Scop::Binds &new_bind);
-  void UpdateFractalIntInfo(int range_idx);
-  void UpdateFractalIntInfoConvForward(int range_idx);
-  void UpdateFractalIntInfoConvBackpropFilter(int range_idx);
-  Stmt ConstructPolyGemm(const Expr &cond = Expr());
-  Stmt ConstructGemm(const Binds &gemm_bind, const Expr &cond = Expr());
-  Stmt ConstructGemmReduceBody(const Binds &gemm_bind, const Expr &mad_init_cond, const GemmVar &gv);
-  static Stmt ConstructFor(int init, Expr cond_exp, const VarExpr &iter, const Stmt &s);
-  std::string ConstructGemmDimensionInfo();
-  std::string AutoConstructGemmDimensionInfo();
-  void CheckConvGemmParam();
-  static int64_t AutoConvMNKTile(const std::string &param_name, int64_t param_size);
-  bool CheckFeatureTensorShape(const Array &shape);
-  bool CheckFilterTensorShape(const Array &shape);
-  static Tensor FindBindTensor(const Binds &bind, const std::string &name);
-  int GetAttrValue(const std::string &key);
-  int GetMAxisSetDim();
-
-  // dynamic
-  void RegisterParam(const Expr &expr);
-  void GetParams();
-  isl::set CreateParamsSet() const;
-  Stmt RestoreCombinedParams(Stmt stmt);
-  void InsertRange(std::map &param_map, const std::pair &item);
-  void InsertPairs(Stmt &stmt, std::map &param_map);
-  void InsertPairsConvTileVar(Stmt &stmt, std::map &param_map);
-  void InsertPairsSpecGemmTileVar(std::map &param_map);
-  void InsertPairsSpecGemmOrConv(Stmt &stmt, std::map &param_map);
-  void Full2PartialDynamic(std::unordered_map &params_map,
-                           const Map &attr_info);
-  Stmt ReplacePrimesWithParameters(Stmt stmt);
-  Expr ReplacePragmaPrimeByVar(Expr prime);
-  Stmt AddTilingStrategyApplet(Stmt stmt);
-
-  // debug
-  void DumpSchTree(const std::string &file_name, const isl::schedule &sch);
-  bool DumpScopData(const std::string &file_name);
-  void DumpScopDataBasics(std::ofstream &of);
-  void DumpScopDataAdvanced(std::ofstream &of);
-  void DumpScopDataScheduleAttrs(std::ofstream &of);
-  std::string AddDumpDir(const std::string &file_name);
-  std::string CreateDumpDir(const std::string &file_name);
-  void DumpBufferDefInfos(std::ostream &out = LOG(INFO));
-};
-
-class PartitionSingle {
- private:
-  static PartitionSingle *single_;
-  static int m_times_;
-  static int m_cut_m_;
-  static std::map m_fractal_int_info_;
-  PartitionSingle(int times, int tile_start, int cut_m, const std::map &fractal_int_info);
-  ~PartitionSingle() = default;
-
- public:
-  static PartitionSingle *CreateInstance(int times, int tile_start, int cut_m,
-                                         const std::map &fractal_int_info) {
-    if (single_ == nullptr) {
-      single_ = new PartitionSingle(times, tile_start, cut_m, fractal_int_info);
-    }
-    return single_;
-  }
-  static PartitionSingle *getInstance() { return single_; }
-  static int getCutM() { return m_cut_m_; }
-  static int getTimes() { return m_times_; }
-  static std::map getFractalInfo() { return m_fractal_int_info_; }
-
-  static void free() {
-    if (single_ != nullptr) {
-      delete single_;
-      single_ = nullptr;
-    }
-  }
+  Stmt body_;
+  isl::ctx ctx_;
 };
-std::pair, std::deque> GenerateTiling(Scop *scop,
-                                       const isl::schedule &,
-                                       const std::vector &,
-                                       const std::vector &);
-isl::union_map ShortSchedule(const isl::schedule_node &node);
-isl::union_map LocalSchedule(const isl::schedule_node &node);
-NodeRef GenerateTilingSpace(Scop *scop, const isl::schedule &, int dump_level,
-                            const std::vector &custom_tiling, const std::vector &dynamic_shape);
-Stmt OptimizeCce(const Stmt &s, bool dynamic_shape = false);
+Stmt GenHalide(ScopInfo &info, const isl::schedule &, bool used_for_tile_out_band = false);
+Stmt DavinciHalideOptimizer(const Stmt &s, bool dynamic_shape = false);
+Stmt RestoreCombinedParams(Stmt stmt, ScopInfo &info);
+std::pair> GenerateTiling(const isl::schedule &sch, ScopInfo &scop_info, Stmt body);
+NodeRef GenerateTilingSpace(const isl::schedule &sch, ScopInfo &scop_info, Stmt body, int dump_level);
 }  // namespace poly
 }  // namespace ir
 }  // namespace akg
diff --git a/src/poly/scop_builder.cc b/src/poly/scop_builder.cc
index 5ac14857635ebb84b02b2cf6b2834740faeff2bf..539df9bc8dbe261f96fad3b4a139d8cd5f262083 100644
--- a/src/poly/scop_builder.cc
+++ b/src/poly/scop_builder.cc
@@ -13,19 +13,15 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 #include "poly/scop_builder.h"
 #include
 #include
-#include
-#include
-#include
-#include
-#include
-#include
-
 #include "pass/utils.h"
+#include "construct_poly_accesses.h"
+
 namespace akg {
 namespace ir {
 namespace poly {
@@ -237,34 +233,6 @@ isl::map AddSuffix4Accesses(AccessMap &accesses, const isl::map &in_map, const N
   return tensor_map;
 }
-std::pair ConstructPolyAccess(const OperatorDomainSpace &domain, const Node *op,
-                              const std::string &tensor, const Array &dimensions,
-                              AccessMap &accesses) {
-  // create a tensor coordinate to store the accessed relation
-  auto coordinate =
-    CollectTensorCoordinate(domain.param_space, isl::id(domain.param_space.ctx(), tensor), dimensions.size());
-  auto tensor_space = coordinate.get_space();
-
-  // create a fully access set
-  isl::set tensor_access = isl::set::universe(tensor_space);
-
-  // add access relation constraint for each parameter of one dimension
-  auto identity = isl::multi_aff::identity(tensor_space.map_from_set());
-  for (size_t dim_idx = 0; dim_idx < dimensions.size(); ++dim_idx) {
-    // make aff bounds of each dimension.
-    auto domain_aff_bounds = Expr2Aff(domain.param_space, dimensions[dim_idx]);
-    if (!domain_aff_bounds.is_null()) {
-      domain_aff_bounds = domain_aff_bounds.unbind_params_insert_domain(coordinate);
-      tensor_access = tensor_access.intersect(domain_aff_bounds.eq_set(identity.get_aff(static_cast(dim_idx))));
-    }
-  }
-
-  auto tensor_map =
-    AddSuffix4Accesses(accesses, tensor_access.unbind_params_insert_domain(domain.tuple), op, domain.param_space.ctx());
-
-  return {tensor_map, isl::map::from(identity)};
-}
-
 isl::union_pw_aff GetUnionPwAffAtDomain(const isl::aff &f, const isl::union_set &domain, const OperatorDomainMap &map) {
   auto upa = isl::union_pw_aff::empty(domain.space());
   for (auto set : domain.get_set_list()) {
@@ -273,511 +241,9 @@ isl::union_pw_aff GetUnionPwAffAtDomain(const isl::aff &f, const isl::union_set
   return upa;
 }
-std::tuple ConstructPolyAccesses(const OperatorDomainSpace &domain,
-                                 const Stmt &s, AccessMap &accesses) {
-  class AttrsExtractor final : public IRVisitor {
-   public:
-    AttrsExtractor() {}
-    ~AttrsExtractor() override = default;
-
-    void Apply(const Stmt &s) { IRVisitor::Visit(s); }
-
-    void Visit_(const AttrStmt *op) override {
-      if (op->attr_key == ATTR_IM2COL_KEY) {
-        Map var_map = Downcast>(op->node);
-        for (auto item : var_map) {
-          if (item.first == ATTR_PRAGMA_OUT_H) {
-            m_out_h = item.second.as() != nullptr ? static_cast(item.second.as()->value) : 0;
-          } else if (item.first == ATTR_PRAGMA_OUT_W) {
-            m_out_w = item.second.as() != nullptr ?
static_cast(item.second.as()->value) : 0; - } - } - } - IRVisitor::Visit_(op); - } - - void Visit_(const Evaluate *op) override { - CHECK(op); - const int im2_col_arg_num = 23; - enum Im2colCallIndex { - idxStrideH = 7, - idxStrideW, - idxKernelH, - idxKernelW, - idxPadTop = 17, - idxPadBottom, - idxPadLeft, - idxPadRight - }; - const Call *call = op->value.as(); - CHECK(call); - auto getCallValue = [&call](const Im2colCallIndex &idx) { - if (auto item = call->args[static_cast(idx)].as()) { - return static_cast(item->value); - } - return 0; - }; - if (call->name == CALL_IM2COL_UB && call->args.size() == im2_col_arg_num) { - m_strid_h = getCallValue(Im2colCallIndex::idxStrideH); - m_strid_w = getCallValue(Im2colCallIndex::idxStrideW); - m_kernel_h = getCallValue(Im2colCallIndex::idxKernelH); - m_kernel_w = getCallValue(Im2colCallIndex::idxKernelW); - m_pad_top = getCallValue(Im2colCallIndex::idxPadTop); - m_pad_bottom = getCallValue(Im2colCallIndex::idxPadBottom); - m_pad_left = getCallValue(Im2colCallIndex::idxPadLeft); - m_pad_right = getCallValue(Im2colCallIndex::idxPadRight); - } - IRVisitor::Visit_(op); - } - - int KernelH() const { return m_kernel_h; } - - int KernelW() const { return m_kernel_w; } - int OutH() const { return m_out_h; } - int OutW() const { return m_out_w; } - int StrideH() const { return m_strid_h; } - int StrideW() const { return m_strid_w; } - int PadLeft() const { return m_pad_left; } - int PadRight() const { return m_pad_right; } - int PadTop() const { return m_pad_top; } - int PadBottom() const { return m_pad_bottom; } - - private: - int m_kernel_h{0}; - int m_kernel_w{0}; - int m_out_h{0}; - int m_out_w{0}; - int m_strid_h{0}; - int m_strid_w{0}; - int m_pad_left{0}; - int m_pad_right{0}; - int m_pad_top{0}; - int m_pad_bottom{0}; - }; - class RelationAccessesParser final : public IRVisitor { - public: - isl::map ExtractIm2ColReadAccess(const std::string &tensor, const Array &shape) { - const int arg_num = shape.size(); - isl::space param_space = domain.param_space; - isl::id tensor_id(param_space.ctx(), tensor); - auto coordinate = CollectTensorCoordinate(param_space, tensor_id, arg_num); - auto tensor_space = coordinate.get_space(); - - isl::set access = isl::set::universe(tensor_space); - auto identity = isl::multi_aff::identity(tensor_space.map_from_set()); - // need to optimize automatic add this exprs - Array args; - auto arg_size = static_cast(param_space.dim(isl_dim_param)); - int k_h = extractor.KernelH(); - int k_w = extractor.KernelW(); - int o_h = extractor.OutH(); - int o_w = extractor.OutW(); - if (arg_size == 3) { - CHECK(shape[0].as()); - args.push_back(shape[0].as()->value > 0 ? 
static_cast(Var("i")) : Expr(0)); - } else { - args.push_back(VarExpr("j") * Expr(16) / Expr(o_h * o_w)); - } - VarExpr k("k"); - CHECK_GT(k_h, 0); - CHECK_GT(k_w, 0); - Expr v = k / Expr(k_h * k_w); - args.push_back(v); - for (size_t i = 0; i < args.size(); ++i) { - auto range_point = identity.get_aff(static_cast(i)); - auto domain_point = Expr2Aff(param_space, args[i]); - if (!domain_point.is_null()) { - domain_point = domain_point.unbind_params_insert_domain(coordinate); - access = access.intersect(domain_point.eq_set(range_point)); - } - } - auto map = access.unbind_params_insert_domain(domain.tuple); - - std::string tag = "__poly_ref_0"; - isl::id tag_id(domain.param_space.ctx(), tag); - auto domain_space = map.get_space().domain(); - auto tag_space = domain_space.params().add_named_tuple_id_ui(tag_id, 0); - domain_space = domain_space.product(tag_space).unwrap(); - map = map.preimage_domain(isl::multi_aff::domain_map(domain_space)); - enum FeatureMapIndex { kBatchIndex = 0, kC1Index, kHIndex, kWIndex, kC0Index, KFeatureMapSiz }; - - CHECK_EQ(shape.size(), FeatureMapIndex::KFeatureMapSiz); - isl::set range = map.range(); - /*********************** - * no cut in H axis - * 0<= arg2 <= fm_h-1 - * 0<= arg3 <= fm_w-1 - * 0<= arg4 <= 16-1 - ************************/ - if (arg_size == 2) { - range = range.lower_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kBatchIndex), 0); - CHECK(shape[static_cast(FeatureMapIndex::kBatchIndex)].as()); - range = range.upper_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kBatchIndex), - shape[static_cast(FeatureMapIndex::kBatchIndex)].as()->value - 1); - } - CHECK(shape[static_cast(FeatureMapIndex::kHIndex)].as() && - shape[static_cast(FeatureMapIndex::kWIndex)].as() && - shape[static_cast(FeatureMapIndex::kC0Index)].as()); - - range = range.lower_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kHIndex), 0); - range = range.upper_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kHIndex), - shape[static_cast(FeatureMapIndex::kHIndex)].as()->value - 1); - range = range.lower_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kWIndex), 0); - range = range.upper_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kWIndex), - shape[static_cast(FeatureMapIndex::kWIndex)].as()->value - 1); - range = range.lower_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kC0Index), 0); - range = range.upper_bound_si(isl_dim_set, static_cast(FeatureMapIndex::kC0Index), - shape[static_cast(FeatureMapIndex::kC0Index)].as()->value - 1); - - map = map.intersect_range(range); - - return map; - } - - bool UpdateAccess(const Array &shape) const { - const size_t kHIndex = 2; - const int largeHSize = 200; - Expr fm_h = shape[kHIndex]; - if (extractor.PadTop() > 0 && extractor.PadBottom() > 0 && extractor.PadLeft() > 0 && extractor.PadRight() > 0 && - Compare(fm_h, Expr(largeHSize)) > 0) { - return true; - } - return false; - } - - std::string getConstraint(const std::string &min_j, const std::string &max_j, const std::string &min_h, - const std::string &max_h) { - std::ostringstream ss; - ss << "(" << min_j << " <= j <= " << max_j << " and " << min_h << " <= arg2 <= " << max_h << ")"; - std::string set_con = ss.str(); - return set_con; - } - - std::string toString(int i) { - std::ostringstream ss; - ss << i; - return ss.str(); - } - - std::string body(bool left) { - std::ostringstream ss; - if (left) { - ss << extractor.StrideH() << "j/" << extractor.KernelH() << " - " << extractor.PadLeft(); - } else { - ss << extractor.StrideH() << "j/" << extractor.KernelH() << " + " 
<< extractor.PadRight(); - } - return ss.str(); - } - - void UpdatePaddingConstraint(const Expr &fmH) { - int size_h = 0; - if (fmH.as()) { - size_h = static_cast(fmH.as()->value); - } - const int mi = 16; - const int cut_h = 2; - int size_m = extractor.OutH() * extractor.OutW() / mi; - int head_m = cut_h * extractor.OutW() / mi; - - int head_h = extractor.KernelH() + (cut_h - 1) * extractor.StrideH() - extractor.PadTop() - 1; - int tail_h = (extractor.OutH() - cut_h) * extractor.StrideH() - extractor.PadTop(); - - std::string head_con = getConstraint(toString(0), toString(head_m - 1), toString(0), toString(head_h)); - std::string tail_con = - getConstraint(toString(size_m - head_m), toString(size_m - 1), toString(tail_h), toString(size_h - 1)); - std::string body_con = getConstraint(toString(head_m), toString(size_m - head_m - 1), body(true), body(false)); - - auto map_str = reads.to_str(); - std::string constraint = " (" + head_con + " or " + body_con + " or " + tail_con + ") "; - size_t endPos = map_str.find("}"); - std::string main = map_str.substr(0, endPos); - main = main + " and " + constraint + " }"; - isl_union_map *read_tmp = isl_union_map_read_from_str(reads.ctx().get(), main.c_str()); - CHECK(read_tmp); - reads = isl::manage(read_tmp); - } - - isl::map ExtractIm2ColWriteAccess(const std::string &tensor, const Array &shape) { - int arg_num = shape.size(); - isl::space param_space = domain.param_space; - isl::id tensor_id(param_space.ctx(), tensor); - auto coordinate = CollectTensorCoordinate(param_space, tensor_id, arg_num); - auto tensor_space = coordinate.get_space(); - - isl::set access = isl::set::universe(tensor_space); - auto identity = isl::multi_aff::identity(tensor_space.map_from_set()); - // need to optimize automatic add this exprs - auto arg_size = static_cast(param_space.dim(isl_dim_param)); - Array args; - const std::vector consStr5D = {"i", "j", "k", "mi", "ni"}; - const std::vector consStr4D = {"j", "k", "mi", "ni"}; - enum ShapeDim { shape5D = 0, shape4D }; - ShapeDim mod = ShapeDim::shape5D; - if (consStr5D.size() == shape.size()) { - mod = ShapeDim::shape5D; - for (size_t i = 0; i < arg_size; ++i) { - if (i == 0) { - CHECK(shape[0].as()); - Expr e = shape[0].as()->value > 0 ? 
static_cast(Var(consStr5D[i])) : Expr(0); - args.push_back(e); - } else { - args.push_back(static_cast(Var(consStr5D[i]))); - } - } - } else if (consStr4D.size() == shape.size()) { - mod = ShapeDim ::shape4D; - for (size_t i = 0; i < arg_size; ++i) { - args.push_back(static_cast(Var(consStr4D[i]))); - } - } - - for (size_t i = 0; i < args.size(); ++i) { - auto range_point = identity.get_aff(static_cast(i)); - auto domain_point = Expr2Aff(param_space, args[i]); - if (!domain_point.is_null()) { - domain_point = domain_point.unbind_params_insert_domain(coordinate); - access = access.intersect(domain_point.eq_set(range_point)); - } - } - - auto map = access.unbind_params_insert_domain(domain.tuple); - - std::string tag = "__poly_ref_1"; - isl::id tag_id(domain.param_space.ctx(), tag); - auto domain_space = map.get_space().domain(); - auto tag_space = domain_space.params().add_named_tuple_id_ui(tag_id, 0); - domain_space = domain_space.product(tag_space).unwrap(); - map = map.preimage_domain(isl::multi_aff::domain_map(domain_space)); - - enum FractalIndex { idxBatch = 0, idxMo, idxKo, idxMi, idxKi, fractalSize }; - /*********************** - * mi ni range definition - * 0<= arg3 <= 16-1 - * 0<= arg4 <= 16-1 - ************************/ - CHECK_EQ(shape.size(), FractalIndex::fractalSize - mod); - CHECK(shape[static_cast(FractalIndex::idxMi - mod)].as() && - shape[static_cast(FractalIndex::idxKi - mod)].as()); - isl::set range = map.range(); - - range = range.lower_bound_si(isl_dim_set, static_cast(FractalIndex::idxMi - mod), 0); - range = range.upper_bound_si(isl_dim_set, static_cast(FractalIndex::idxMi - mod), - shape[static_cast(FractalIndex::idxMi - mod)].as()->value - 1); - - range = range.lower_bound_si(isl_dim_set, static_cast(FractalIndex::idxKi - mod), 0); - range = range.upper_bound_si(isl_dim_set, static_cast(FractalIndex::idxKi - mod), - shape[static_cast(FractalIndex::idxKi - mod)].as()->value - 1); - map = map.intersect_range(range); - - return map; - } - - void Visit_(const Evaluate *op) final { - IRVisitor::Visit_(op); - const Call *call_op = op->value.as(); - if (call_op && call_op->name == CALL_IM2COL_UB) { - CHECK_GE(call_op->args.size(), 2); - CHECK(call_op->args[0].as()); - CHECK_GE(call_op->args[0].as()->args.size(), 2); - CHECK(call_op->args[0].as()->args[1].as()); - CHECK(call_op->args[1].as()); - CHECK_GE(call_op->args[1].as()->args.size(), 2); - CHECK(call_op->args[1].as()->args[1].as()); - std::string write_buffer = call_op->args[0].as()->args[1].as()->name_hint; - std::string read_buffer = call_op->args[1].as()->args[1].as()->name_hint; - for (auto item : accesses) { - if (item.first->IsInstance()) { - auto attr = static_cast(item.first); - Array array = Downcast>(attr->node); - Buffer buffer = Downcast(array[0]); - Tensor tensor = Downcast(array[1]); - if (buffer->name == read_buffer) { - isl::map readIm2Col = ExtractIm2ColReadAccess(tensor->op->name, tensor->shape); - reads = reads.unite(readIm2Col); - if (UpdateAccess(tensor->shape)) { - UpdatePaddingConstraint(tensor->shape[2]); - } - } else if (buffer->name == write_buffer) { - isl::map writeIm2Col = ExtractIm2ColWriteAccess(tensor->op->name, tensor->shape); - writes = writes.unite(writeIm2Col); - } - } - } - } - } - - void Visit_(const Call *op) final { - IRVisitor::Visit_(op); - if (op->call_type == Call::Halide) { - isl::map reads_tmp, toinner_tmp; - std::string var_name = op->name; - if (op->func.defined() && op->func->num_outputs() != 1) { - var_name = var_name + "_v" + std::to_string(op->value_index); - } - 
std::tie(reads_tmp, toinner_tmp) = ConstructPolyAccess(domain, op, var_name, op->args, accesses); - reads = reads.unite(reads_tmp); - to_inner_ = to_inner_.add_map(toinner_tmp); - } - } - - void Visit_(const Provide *op) final { - IRVisitor::Visit_(op); - isl::map writes_tmp, toinner_tmp; - std::string var_name = op->func->func_name(); - if (op->func->num_outputs() != 1) { - var_name = var_name + "_v" + std::to_string(op->value_index); - } - std::tie(writes_tmp, toinner_tmp) = ConstructPolyAccess(domain, op, var_name, op->args, accesses); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.add_map(toinner_tmp); - } - - /* The conditionals of IfThenElse statements may fall in these cases. - * The accesses should be updated to read sets of scop as such accesses - * may only be read. - * - * More complicated cases like conditionals involving Store and/or - * Provide should also update write sets. - */ - void Visit_(const EQ *op) final { - isl::union_map reads_tmp, writes_tmp, toinner_tmp; - - Stmt stmt_a(GetObjPtr(op->a.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - - Stmt stmt_b(GetObjPtr(op->b.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - } - - void Visit_(const NE *op) final { - isl::union_map reads_tmp, writes_tmp, toinner_tmp; - - Stmt stmt_a(GetObjPtr(op->a.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - - Stmt stmt_b(GetObjPtr(op->b.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - } - - void Visit_(const LT *op) final { - isl::union_map reads_tmp, writes_tmp, toinner_tmp; - - Stmt stmt_a(GetObjPtr(op->a.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - - Stmt stmt_b(GetObjPtr(op->b.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - } - - void Visit_(const LE *op) final { - isl::union_map reads_tmp, writes_tmp, toinner_tmp; - - Stmt stmt_a(GetObjPtr(op->a.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - - Stmt stmt_b(GetObjPtr(op->b.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - } - - void Visit_(const GT *op) final { - isl::union_map reads_tmp, writes_tmp, toinner_tmp; - - Stmt stmt_a(GetObjPtr(op->a.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, 
accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - - Stmt stmt_b(GetObjPtr(op->b.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - } - - void Visit_(const GE *op) final { - isl::union_map reads_tmp, writes_tmp, toinner_tmp; - - Stmt stmt_a(GetObjPtr(op->a.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_a, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - - Stmt stmt_b(GetObjPtr(op->b.get())); - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, stmt_b, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - } - - // End of conditionals of IfThenElse, more cases are pending. - - /* A For type statement may be visited in the presence of - * IfThenElse in the scop, as the body of the enclosing - * if statement. - * - * A Block type should be handled. - */ - - void Visit_(const For *op) final { - IRVisitor::Visit_(op); - isl::union_map reads_tmp, writes_tmp, toinner_tmp; - - std::tie(reads_tmp, writes_tmp, toinner_tmp) = ConstructPolyAccesses(domain, op->body, accesses); - reads = reads.unite(reads_tmp); - writes = writes.unite(writes_tmp); - to_inner_ = to_inner_.unite(toinner_tmp); - } - - const OperatorDomainSpace &domain; - AccessMap &accesses; - - isl::union_map reads, writes; - isl::union_map to_inner_; - AttrsExtractor extractor; - - RelationAccessesParser(const Stmt stmt, const OperatorDomainSpace &space, AccessMap &accesses) - : domain(space), - accesses(accesses), - reads(isl::union_map::empty(domain.tuple.get_space())), - writes(isl::union_map::empty(domain.tuple.get_space())), - to_inner_(isl::union_map::empty(domain.tuple.get_space())) { - extractor.Apply(stmt); - IRVisitor::Visit(stmt); - } - ~RelationAccessesParser() override = default; - } parser(s, domain, accesses); - return std::make_tuple(parser.reads, parser.writes, parser.to_inner_); -} - static const char kStatementLabel[] = "S_"; -bool ParseWithStmt(const Expr &s, const Scop::Data &data) { +bool ParseWithStmt(const Expr &s, const AnalysisResult &result) { class ParseWith final : public IRVisitor { public: void Visit_(const Call *op) final { @@ -795,15 +261,15 @@ bool ParseWithStmt(const Expr &s, const Scop::Data &data) { bool GetResult() const { return find_tensor; } - ParseWith(const Expr &stmt, const Scop::Data &data) { - data.writes.foreach_map([&, this](const isl::map m) -> void { + ParseWith(const Expr &stmt, const AnalysisResult &result) { + result.GetWrites().foreach_map([&, this](const isl::map m) -> void { writes.insert(m.get_tuple_id(isl_dim_out).get_name()); return; }); IRVisitor::Visit(stmt); } ~ParseWith() override = default; - } paserWith(s, data); + } paserWith(s, result); return paserWith.GetResult(); } @@ -845,17 +311,17 @@ std::map call_op_ = { {"vmla", PolyOpType::vmla}, }; -void ParseStmtOpCall(const isl::id &id, const Call *call, Scop::Data &data, const FunctionRef &func) { +void ParseStmtOpCall(const isl::id &id, const Call *call, AnalysisResult &result, const FunctionRef &func) { CHECK(call); if (call->call_type == Call::PureIntrinsic) { if (call_op_.count(call->name) > 0) { - 
data.stmt_op_Info.at(id).ops.push_back(call_op_[call->name]);
+      result.GetStmtOpInfoMap().at(id).ops.push_back(call_op_[call->name]);
     } else if (0 == strcmp(call->name.c_str(), "with")) {
-      data.stmt_op_Info.at(id).ops.push_back(PolyOpType::with);
-      if (!data.stmt_op_Info.at(id).isWith) {
+      result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::with);
+      if (!result.GetStmtOpInfoMap().at(id).isWith) {
         for (unsigned i = 0; i < call->args.size(); ++i) {
-          if (ParseWithStmt(call->args[i], data)) {
-            data.stmt_op_Info.at(id).isWith = true;
+          if (ParseWithStmt(call->args[i], result)) {
+            result.GetStmtOpInfoMap().at(id).isWith = true;
             break;
           }
         }
@@ -869,22 +335,22 @@ void ParseStmtOpCall(const isl::id &id, const Call *call, Scop::Data &data, cons
     } else if (0 == strcmp(call->name.c_str(), "sub_relu")) {
       // do nothing
     } else if (0 == strcmp(call->name.c_str(), "load3d_l1_ub")) {
-      data.stmt_op_Info.at(id).isLoad3d = true;
-      ParseStmtOps(id, call->args[0], data, func);
+      result.GetStmtOpInfoMap().at(id).isLoad3d = true;
+      ParseStmtOps(id, call->args[0], result, func);
     } else if (0 == strcmp(call->name.c_str(), "mad")) {
-      data.stmt_op_Info.at(id).ops.push_back(PolyOpType::mad);
-      data.stmt_op_Info.at(id).isCube = true;
+      result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::mad);
+      result.GetStmtOpInfoMap().at(id).isCube = true;
       // assign + mad
       std::string name = id.get_name();
       size_t index = static_cast(WrappedStrtol(name.substr(name.length() - 1)));
       std::string tmp = name.substr(0, name.length() - 1);
       std::stringstream ss;
       ss << tmp << index - 1;
-      if (data.stmt_op_Info.count(isl::id(id.ctx(), ss.str())) > 0 &&
-          data.stmt_op_Info.at(isl::id(id.ctx(), ss.str())).ops[0] == PolyOpType::broadcast)
-        data.stmt_op_Info.at(isl::id(id.ctx(), ss.str())).isCubeAssign = true;
+      if (result.GetStmtOpInfoMap().count(isl::id(id.ctx(), ss.str())) > 0 &&
+          result.GetStmtOpInfoMap().at(isl::id(id.ctx(), ss.str())).ops[0] == PolyOpType::broadcast)
+        result.GetStmtOpInfoMap().at(isl::id(id.ctx(), ss.str())).isCubeAssign = true;
       // end
-      data.stmt_op_Info.at(id).C_ = func->func_name();
+      result.GetStmtOpInfoMap().at(id).C_ = func->func_name();
       CHECK(call->args.size() == 2) << "invalid args of mad! ";
       auto mul_arg = call->args[0].as() ? call->args[0].as() : call->args[1].as();
@@ -897,18 +363,19 @@ void ParseStmtOpCall(const isl::id &id, const Call *call, Scop::Data &data, cons
       auto b = mul_arg->b.as();
       // in gemm case, C = mad(C, A * B)
       if (a && b) {
-        data.stmt_op_Info.at(id).A_ = a->name;
-        data.stmt_op_Info.at(id).B_ = b->name;
+        result.GetStmtOpInfoMap().at(id).A_ = a->name;
+        result.GetStmtOpInfoMap().at(id).B_ = b->name;
       }
       // in conv case, reassign A&B by attr
       if (func.as() != nullptr) {
-        data.stmt_op_Info.at(id).MadType_ = call->args[1].as() ? call->args[1].as()->type : Float(16);
+        result.GetStmtOpInfoMap().at(id).MadType_ =
+          call->args[1].as() ?
call->args[1].as()->type : Float(16); for (auto i : func.as()->attrs) { if ("feature" == i.first) { - data.stmt_op_Info.at(id).A_ = i.second.as()->value; + result.GetStmtOpInfoMap().at(id).A_ = i.second.as()->value; } if ("filter" == i.first) { - data.stmt_op_Info.at(id).B_ = i.second.as()->value; + result.GetStmtOpInfoMap().at(id).B_ = i.second.as()->value; } } } @@ -918,121 +385,121 @@ void ParseStmtOpCall(const isl::id &id, const Call *call, Scop::Data &data, cons } } -void ParseStmtOps(const isl::id &id, const Expr &val, Scop::Data &data, const FunctionRef &func) { - data.stmt_op_Info.at(id).isCube = false; - data.stmt_op_Info.at(id).isCubeAssign = false; +void ParseStmtOps(const isl::id &id, const Expr &val, AnalysisResult &result, const FunctionRef &func) { + result.GetStmtOpInfoMap().at(id).isCube = false; + result.GetStmtOpInfoMap().at(id).isCubeAssign = false; if (auto add = val.as()) { if (isImm(add->a) || isImm(add->b)) { if (!isImm(add->a)) { // if add->a is not a scalar, then put it into recursion - ParseStmtOps(id, add->a, data, func); + ParseStmtOps(id, add->a, result, func); } else if (!isImm(add->b)) { // if add->b is not a scalar, then put it into recursion - ParseStmtOps(id, add->b, data, func); + ParseStmtOps(id, add->b, result, func); } else { // if add->a and add->b are both scalar, then report error LOG(FATAL) << "Error: Scalar + Scalar, Please Check."; } - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_single_VS_add); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_single_VS_add); } else { - ParseStmtOps(id, add->a, data, func); - ParseStmtOps(id, add->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_add); + ParseStmtOps(id, add->a, result, func); + ParseStmtOps(id, add->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_add); } } else if (auto sub = val.as()) { - ParseStmtOps(id, sub->a, data, func); - ParseStmtOps(id, sub->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_sub); + ParseStmtOps(id, sub->a, result, func); + ParseStmtOps(id, sub->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_sub); } else if (auto mul = val.as()) { if (isImm(mul->a) || isImm(mul->b)) { // if mul->a is not a scalar, then put it into recursion if (!isImm(mul->a)) { - ParseStmtOps(id, mul->a, data, func); + ParseStmtOps(id, mul->a, result, func); } else if (!isImm(mul->b)) { // if mul->b is not a scalar, then put it into recursion - ParseStmtOps(id, mul->b, data, func); + ParseStmtOps(id, mul->b, result, func); } else { // if mul->a and mul->b are both scalar, then report error LOG(FATAL) << "Error: Scalar + Scalar, Please Check."; } if (isZero(mul->b) || isZero(mul->a)) { - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::broadcast); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::broadcast); } else { - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_single_VS_mul); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_single_VS_mul); } } else { - ParseStmtOps(id, mul->a, data, func); - ParseStmtOps(id, mul->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_mul); + ParseStmtOps(id, mul->a, result, func); + ParseStmtOps(id, mul->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_mul); } } else if (auto f_div = val.as()) { - ParseStmtOps(id, f_div->a, data, func); - ParseStmtOps(id, 
f_div->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_div); + ParseStmtOps(id, f_div->a, result, func); + ParseStmtOps(id, f_div->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_div); } else if (auto f_mod = val.as()) { - ParseStmtOps(id, f_mod->a, data, func); - ParseStmtOps(id, f_mod->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_mod); + ParseStmtOps(id, f_mod->a, result, func); + ParseStmtOps(id, f_mod->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_mod); } else if (auto div = val.as
()) { - ParseStmtOps(id, div->a, data, func); - ParseStmtOps(id, div->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_div); + ParseStmtOps(id, div->a, result, func); + ParseStmtOps(id, div->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_div); } else if (auto mod = val.as()) { - ParseStmtOps(id, mod->a, data, func); - ParseStmtOps(id, mod->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_mod); + ParseStmtOps(id, mod->a, result, func); + ParseStmtOps(id, mod->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_mod); } else if (auto and_op = val.as()) { - ParseStmtOps(id, and_op->a, data, func); - ParseStmtOps(id, and_op->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_and); + ParseStmtOps(id, and_op->a, result, func); + ParseStmtOps(id, and_op->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_and); } else if (auto or_op = val.as()) { - ParseStmtOps(id, or_op->a, data, func); - ParseStmtOps(id, or_op->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_or); + ParseStmtOps(id, or_op->a, result, func); + ParseStmtOps(id, or_op->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_or); } else if (auto min = val.as()) { - ParseStmtOps(id, min->a, data, func); - ParseStmtOps(id, min->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_min); + ParseStmtOps(id, min->a, result, func); + ParseStmtOps(id, min->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_min); } else if (auto max = val.as()) { - ParseStmtOps(id, max->a, data, func); - ParseStmtOps(id, max->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::elewise_binary_max); + ParseStmtOps(id, max->a, result, func); + ParseStmtOps(id, max->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_max); } else if (auto ge = val.as()) { - ParseStmtOps(id, ge->a, data, func); - ParseStmtOps(id, ge->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::pandora_cmp); + ParseStmtOps(id, ge->a, result, func); + ParseStmtOps(id, ge->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto gt = val.as()) { - ParseStmtOps(id, gt->a, data, func); - ParseStmtOps(id, gt->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::pandora_cmp); + ParseStmtOps(id, gt->a, result, func); + ParseStmtOps(id, gt->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto le = val.as()) { - ParseStmtOps(id, le->a, data, func); - ParseStmtOps(id, le->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::pandora_cmp); + ParseStmtOps(id, le->a, result, func); + ParseStmtOps(id, le->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto lt = val.as()) { - ParseStmtOps(id, lt->a, data, func); - ParseStmtOps(id, lt->b, data, func); - data.stmt_op_Info.at(id).ops.push_back(PolyOpType::pandora_cmp); + ParseStmtOps(id, lt->a, result, func); + ParseStmtOps(id, lt->b, result, func); + result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto eq = val.as()) { - ParseStmtOps(id, eq->a, data, func); - ParseStmtOps(id, 
eq->b, data, func);
-    data.stmt_op_Info.at(id).ops.push_back(PolyOpType::pandora_cmp);
+    ParseStmtOps(id, eq->a, result, func);
+    ParseStmtOps(id, eq->b, result, func);
+    result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp);
   } else if (auto ne = val.as()) {
-    ParseStmtOps(id, ne->a, data, func);
-    ParseStmtOps(id, ne->b, data, func);
-    data.stmt_op_Info.at(id).ops.push_back(PolyOpType::pandora_cmp);
+    ParseStmtOps(id, ne->a, result, func);
+    ParseStmtOps(id, ne->b, result, func);
+    result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp);
   } else if ((isImm(val) || val.type().is_int()) && val.as() == nullptr) {
-    data.stmt_op_Info.at(id).ops.push_back(PolyOpType::broadcast);
+    result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::broadcast);
  } else if (auto sel = val.as