diff --git a/src/pass/inject_thread_bind.cc b/src/pass/inject_thread_bind.cc index 1843bcdd2f385a3ace5452cb10869c2595f70300..4840ea586d9a1fdccfce4b7102eabb9e58038168 100644 --- a/src/pass/inject_thread_bind.cc +++ b/src/pass/inject_thread_bind.cc @@ -1371,18 +1371,53 @@ class MultiCoreScalarMerge : public IRMutator { class ScalarPeel : public IRMutator { public: - Stmt Run(const Stmt &s) { return Mutate(s); } + Stmt Run(const Stmt &s) { + Stmt res = Mutate(s); + if (!before_scalar_store_) { + multi_core_body_ = s; + return Stmt(); + } + return res; + } + Stmt multi_core_body_; + bool find_multi_core_{false}; + bool before_scalar_store_{false}; private: + Stmt Mutate_(const Store *op, const Stmt &s) final { + if (!find_multi_core_) before_scalar_store_ = true; + return IRMutator::Mutate_(op, s); + } + + bool MultiCoreAttr(const AttrStmt *op) { + if (op->attr_key == "pragma_multi_core_depth" && + Compare(op->value, make_const(op->value.type(), 1)) == 0) { + return true; + } + return false; + } + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if (op->attr_key == "pragma_multi_core_depth" && Compare(op->value, make_const(op->value.type(), 1)) == 0) { + if (MultiCoreAttr(op)) { + find_multi_core_ = true; Stmt body = Mutate(op->body); multi_core_body_ = AttrStmt::make(op->node, op->attr_key, op->value, body); return AttrStmt::make(op->node, op->attr_key, op->value, Evaluate::make(0)); } return IRMutator::Mutate_(op, s); } + + Stmt Mutate_(const Block *op, const Stmt &s) final { + if (op->first.defined() && op->rest.defined() && + op->first.as() != nullptr && MultiCoreAttr(op->first.as())) { + auto first = Mutate(op->first); + auto rest = Mutate(op->rest); + multi_core_body_ = Block::make(multi_core_body_, rest); + return Block::make(first, Evaluate::make(0)); + } + return IRMutator::Mutate_(op, s); + } }; Stmt InjectMultiCore(Stmt stmt, int max_block_dim, int merge_outer_loop, bool is_dynamic, bool scalar_rearrange) { @@ -1393,7 +1428,7 @@ Stmt InjectMultiCore(Stmt stmt, int max_block_dim, int merge_outer_loop, bool is Stmt scalar_part; if (scalar_rearrange) { ScalarPeel peel; - scalar_part = peel.Mutate(stmt); + scalar_part = peel.Run(stmt); stmt = peel.multi_core_body_; } @@ -1419,7 +1454,7 @@ Stmt InjectMultiCore(Stmt stmt, int max_block_dim, int merge_outer_loop, bool is stmt = MultiCoreInsert(plan.block_num_, plan.block_coef_).Insert(stmt); } stmt = LoopUnCompunder().Mutate(stmt); - if (scalar_rearrange) { + if (scalar_rearrange && scalar_part.defined()) { stmt = MultiCoreScalarMerge().Run(stmt, scalar_part); } return stmt; diff --git a/src/poly/cce_isl_emitter.cc b/src/poly/cce_isl_emitter.cc index ed1dc8895c836a45de04c95a3e0663aa49db3f46..d15af812b36eea85d0b55de87030a06ae956776e 100644 --- a/src/poly/cce_isl_emitter.cc +++ b/src/poly/cce_isl_emitter.cc @@ -491,14 +491,15 @@ bool CCEIslEmitter::InjectMulticore(const std::string &iter) { bool is_loop_in_multicore_band = (coincident_member < multicore_info.coincidence.size()); if (is_loop_in_multicore_band) { should_insert_multi_core = multicore_info.coincidence[coincident_member]; + if (should_insert_multi_core) { + ++multicore_info.multicore_depth; + --multicore_info.coincidence[coincident_member]; + } } } else { LOG(WARNING) << "multicore: unrecognized loop var " << iter; } } - if (should_insert_multi_core) { - ++multicore_info.multicore_depth; - } return should_insert_multi_core; } @@ -549,8 +550,8 @@ Stmt CCEIslEmitter::EmitFor(const isl::ast_node_for &node) { if (should_insert_multi_core) { CHECK_EQ(multicore_info.multicore_depth, original_multicore_info.multicore_depth + 1); stmt = AttrStmt::make(make_zero(Int(32)), "pragma_multi_core_depth", Expr(multicore_info.multicore_depth), stmt); + --multicore_info.multicore_depth; } - multicore_info = original_multicore_info; return stmt; }