diff --git a/src/poly/tiling_analyzer.cc b/src/poly/tiling_analyzer.cc index ae2f2a6cde170e665c3127a45e820cbee995d748..2a2a2a05e658c325d184f0bb938a4008a63364f5 100644 --- a/src/poly/tiling_analyzer.cc +++ b/src/poly/tiling_analyzer.cc @@ -122,7 +122,7 @@ void TileAxis::InsertL1CandFactor(const Expr &f) { while (i < this->l1_constraints.cand_factor.size()) { if (Equal(this->l1_constraints.cand_factor[i], f)) { return; - } else if (analyzer_->arith_ana_.CanProve(this->l1_constraints.cand_factor[i] > f)) { + } else if (analyzer_->arith_ana_.CanProve(this->l1_constraints.cand_factor[i] < f)) { break; } ++i; @@ -135,7 +135,7 @@ void TileAxis::InsertL0CandFactor(const Expr &f) { while (i < this->l0_constraints.cand_factor.size()) { if (Equal(this->l0_constraints.cand_factor[i], f)) { return; - } else if (analyzer_->arith_ana_.CanProve(this->l0_constraints.cand_factor[i] > f)) { + } else if (analyzer_->arith_ana_.CanProve(this->l0_constraints.cand_factor[i] < f)) { break; } ++i; @@ -180,8 +180,13 @@ void TileAxis::DumpAxis(bool on_screen) { } if (!this->l1_constraints.cand_factor.empty()) { ss << "| L1 Cand_factors:{"; - for (const auto &f : this->l1_constraints.cand_factor) { - ss << f << ","; + bool full_dump = this->l1_constraints.cand_factor.size() <= 10; + if (full_dump) { + for (const auto &f : this->l1_constraints.cand_factor) { + ss << f << ","; + } + } else { + ss << this->l1_constraints.cand_factor[0] << " ... " << this->l1_constraints.cand_factor.back(); } ss << "} |"; if (on_screen) LOG(INFO) << ss.str(); @@ -189,8 +194,13 @@ void TileAxis::DumpAxis(bool on_screen) { } if (!this->l0_constraints.cand_factor.empty()) { ss << "| L0 Cand_factors:{"; - for (const auto &f : this->l0_constraints.cand_factor) { - ss << f << ","; + bool full_dump = this->l0_constraints.cand_factor.size() <= 10; + if (full_dump) { + for (const auto &f : this->l0_constraints.cand_factor) { + ss << f << ","; + } + } else { + ss << this->l0_constraints.cand_factor[0] << " ... " << this->l0_constraints.cand_factor.back(); } ss << "} |"; if (on_screen) LOG(INFO) << ss.str(); @@ -1310,12 +1320,18 @@ void TilingAnalyzer::AddTilingConstraints() { actived_strategies.push_back(&pd_attr_strategy); CastStrategy cast_strategy(this); ReduceStrategy reduce_strategy(this); VectorizedStrategy vectorized_strategy(this); + DmaAlignStrategy dma_align_strategy(this); TensorOfTensorStrategy tot_strategy(this); actived_strategies.push_back(&cast_strategy); - actived_strategies.push_back(&reduce_strategy); actived_strategies.push_back(&vectorized_strategy); + // actived_strategies stores the addresses of the strategy objects and is still used after this block, + // so the objects are declared at function scope and only their registration is conditional. + if (!is_dynamic_) { + actived_strategies.push_back(&reduce_strategy); + actived_strategies.push_back(&dma_align_strategy); + } actived_strategies.push_back(&tot_strategy); ModStrategy mod_strategy(this); @@ -1512,9 +1528,6 @@ void TilingAnalyzer::DumpBufferUsageTimeable() { lived_buf_name.insert(it.first->name); ss << "Buffer " << it.first->name << " | Alloc time: " << alloc_time << " | Last use time : " << last_use_time << " | "; - ss << "live buf: ["; - for (const auto &name : lived_buf_name) ss << name << ", "; - ss << "]"; logger_.AppendLog(ANA_BUF_LIVE_EXTENT, ss); } } diff --git a/src/poly/tiling_solver.cc b/src/poly/tiling_solver.cc index 7ff4c9810d27a259a50d48979c747d7104c54b76..f3888aa319d1fd965210baf32ca331ba497f0fe3 100644 --- a/src/poly/tiling_solver.cc +++ b/src/poly/tiling_solver.cc @@ -21,17 +21,17 @@ namespace akg { namespace ir { namespace poly { void TilingSolver::CollectMemoryLimit() { - double percentage = ALLOCATION_PERCENTAGE; + percentage_ = ALLOCATION_PERCENTAGE; for (auto attr : analyzer_.RootAxis()->attrs) { if (attr.attr_key != "MEM_RATIO") continue; CHECK_NE(attr.attr_value, ""); - percentage = std::strtod(attr.attr_value.c_str(), nullptr); + percentage_ = std::strtod(attr.attr_value.c_str(), nullptr); break; } DavinciInfo &d_info = DavinciInfo::GetInstance(); for (auto i = 0; i < MEM_SCOPE_BULK; ++i) { - this->mem_limit_[i] = 
d_info.GetMemoryLimitInScope(i) * percentage; + this->mem_limit_[i] = d_info.GetMemoryLimitInScope(i) * percentage_; } } @@ -431,9 +431,7 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem TileAxis::Constraint cons = level == LEVEL1 ? axis->l1_constraints : axis->l0_constraints; if (!cons.cand_factor.empty()) { - for (auto i = static_cast(cons.cand_factor.size()) - 1; i >= 0; --i) { - auto max_cand = cons.cand_factor[i]; - + for (auto max_cand : cons.cand_factor) { if (max_cand.as() == nullptr) { ss << "Static shape should have const candidate factor, while got " << max_cand; analyzer_.logger_.LogFatalAndSaveLog(ss.str()); @@ -443,6 +441,12 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem final_factor = max_cand.as()->value; ss << "--> Candidate factor " << final_factor; break; + } else if (max_cand.as()->value * exceed_ratio_ * percentage_ < static_mem_constraint) { + final_factor = max_cand.as()->value; + exceed_ratio_ = + exceed_ratio_ * (static_cast(final_factor) / static_cast(static_mem_constraint)); + ss << "--> Candidate factor " << final_factor << " (exceed ratio update to " << exceed_ratio_ << ")"; + break; } } } else { @@ -542,11 +546,11 @@ void InequalitySolver::CalculateMemoryInBuffer(const TilingAnalyzer::BufferEntry ExprSimplifier().CanProveWithPosParam(mem_info->live_size[buf->scope] < mem_info->max_live_size[buf->scope]); if (current_is_larger) { - ss << "Can prove current live size" << mem_info->live_size[buf->scope] << " greater than maximal size " + ss << "[Update max] current live size " << mem_info->live_size[buf->scope] << " greater than maximal size " << mem_info->max_live_size[buf->scope]; mem_info->max_live_size[buf->scope] = mem_info->live_size[buf->scope]; } else if (!current_is_smaller) { - ss << "Can not compare current live size" << mem_info->live_size[buf->scope] << " with maximal size " + ss << "[Unknown] Can not compare current live size" << 
mem_info->live_size[buf->scope] << " with maximal size " << mem_info->max_live_size[buf->scope]; mem_info->max_live_size[buf->scope] = CanonicalSimplify(mem_info->max_live_size[buf->scope] + buf_shape); } @@ -620,6 +624,7 @@ void InequalitySolver::UpdateMemInfo() { } void InequalitySolver::UpdateMemInfoWithBufReuse() { + std::stringstream ss; auto mem_info = tiling_mem_info_.get(); CHECK(mem_info); @@ -631,6 +636,7 @@ void InequalitySolver::UpdateMemInfoWithBufReuse() { continue; } if (mem_info->live_size[it.first->scope].defined() && mem_info->live_buf[it.first].defined()) { + ss << "Release buffer " << it.first->name << " with size " << mem_info->live_buf[it.first]; mem_info->live_size[it.first->scope] -= mem_info->live_buf[it.first]; } mem_info->live_buf.erase(it.first); diff --git a/src/poly/tiling_solver.h b/src/poly/tiling_solver.h index a193a51f54d45cf46cb6597b33f5380a31e81ac7..54bd6ede1f32b3bac9ac3ec7b17959f4504d9338 100644 --- a/src/poly/tiling_solver.h +++ b/src/poly/tiling_solver.h @@ -36,6 +36,8 @@ class TilingSolver { TileCandidate cand_; int64_t mem_limit_[MEM_SCOPE_BULK]{0}; int tiling_band_{0}; + double percentage_ = 0.5; + double exceed_ratio_ = 1; // allow memory allocation to exceed memory_size * percentage, may disable double buffer }; class InequalitySolver : TilingSolver { diff --git a/src/poly/tiling_strategy_manager.cc b/src/poly/tiling_strategy_manager.cc index da393e523ce21a2e3a4d647423af48432671d697..5f12276d687b2f3699561d80bb0a95df182c957b 100644 --- a/src/poly/tiling_strategy_manager.cc +++ b/src/poly/tiling_strategy_manager.cc @@ -213,17 +213,7 @@ void CastStrategy::AddConstraint() { void ReduceStrategy::AddConstraint() { for (auto axis : analyzer_->GetAxesOfAttr("REDUCE_DST_LAST")) { - int64_t block_size = GetMaxAlignBytes(axis->data_size); - int64_t const_extent = axis->GetConstExtent(); - if (const_extent == -1) { - continue; - } - int64_t align_elem = ktvm::ir::gcd(block_size, const_extent); - if (align_elem == block_size) { - 
axis->l1_constraints.tile_min_ = align_elem; - } else { - axis->forbid_iso = true; - } + axis->l1_constraints.tile_min_ = CastInt64ToExpr(GetMaxAlignBytes(axis->data_size)); } } @@ -250,6 +240,38 @@ void VectorizedStrategy::AddConstraint() { } } +// Adds DMA alignment constraints to each axis returned by GetAxesContainsAttr("ALIGN") that carries an attribute +// whose key contains both "ALIGN" and "DMA". +void DmaAlignStrategy::AddConstraint() { + for (auto axis : analyzer_->GetAxesContainsAttr("ALIGN")) { + for (const auto &attr : axis->attrs) { + if ((attr.attr_key.find("ALIGN") == std::string::npos) || (attr.attr_key.find("DMA") == std::string::npos)) { + continue; + } + auto align_size = GetMaxAlignBytes(axis->data_size); + + // Held as int64_t so a large extent cannot be narrowed into the -1 value tested below. + int64_t const_extent = axis->GetConstExtent(); + + // For dynamic shape or axes that have other candidates, simply add tile min constraint; + // for static shape that has no other candidate, add aligned candidates. + if (const_extent == -1 || !axis->l1_constraints.cand_factor.empty()) { + axis->l1_constraints.tile_min_ = CastInt64ToExpr(align_size); + } else { + std::vector<Expr> candidates; + for (auto cand = const_extent; cand >= align_size; --cand) { + auto tail = const_extent % cand; + if (tail == 0 || tail >= align_size) { + // NOTE(review): still built with CastIntToExpr so the candidate Expr type is unchanged; confirm whether CastInt64ToExpr is preferred. + candidates.emplace_back(CastIntToExpr(static_cast<int>(cand))); + } + } + axis->l1_constraints.cand_factor = candidates; + } + } + } +} + void TensorOfTensorStrategy::AddConstraint() { for (auto axis : analyzer_->GetAxesOfAttr("TOT")) { if (!axis->HasAttr("ALIGN:DMA")) continue; diff --git a/src/poly/tiling_strategy_manager.h b/src/poly/tiling_strategy_manager.h index 2330f1cd7ed9c2f3bb99270c1601d0ea432d63c0..397557419a2bf2db8be944b0abea0d46b62b36d1 100644 --- a/src/poly/tiling_strategy_manager.h +++ b/src/poly/tiling_strategy_manager.h @@ -112,13 +112,20 @@ class VectorizedStrategy : public TilingStrategy { void AddConstraint(); }; +class DmaAlignStrategy : public TilingStrategy { + public: + explicit DmaAlignStrategy(const TilingAnalyzer *a) : TilingStrategy(a) {} + ~DmaAlignStrategy() {} + void AddConstraint(); + + std::string interested_attr_key = "ALIGN"; +}; + class
TensorOfTensorStrategy : public TilingStrategy { public: explicit TensorOfTensorStrategy(const TilingAnalyzer *a) : TilingStrategy(a) {} ~TensorOfTensorStrategy() {} void AddConstraint(); - - std::string interested_attr_key = "CAST"; }; class PassDownAttrStrategy : public TilingStrategy {