提交 bbaf4bda 编写于 作者: D dabaiji

Add a DMA alignment strategy to auto tiling that enables multicore

上级 225139a0
...@@ -122,7 +122,7 @@ void TileAxis::InsertL1CandFactor(const Expr &f) { ...@@ -122,7 +122,7 @@ void TileAxis::InsertL1CandFactor(const Expr &f) {
while (i < this->l1_constraints.cand_factor.size()) { while (i < this->l1_constraints.cand_factor.size()) {
if (Equal(this->l1_constraints.cand_factor[i], f)) { if (Equal(this->l1_constraints.cand_factor[i], f)) {
return; return;
} else if (analyzer_->arith_ana_.CanProve(this->l1_constraints.cand_factor[i] > f)) { } else if (analyzer_->arith_ana_.CanProve(this->l1_constraints.cand_factor[i] < f)) {
break; break;
} }
++i; ++i;
...@@ -135,7 +135,7 @@ void TileAxis::InsertL0CandFactor(const Expr &f) { ...@@ -135,7 +135,7 @@ void TileAxis::InsertL0CandFactor(const Expr &f) {
while (i < this->l0_constraints.cand_factor.size()) { while (i < this->l0_constraints.cand_factor.size()) {
if (Equal(this->l0_constraints.cand_factor[i], f)) { if (Equal(this->l0_constraints.cand_factor[i], f)) {
return; return;
} else if (analyzer_->arith_ana_.CanProve(this->l0_constraints.cand_factor[i] > f)) { } else if (analyzer_->arith_ana_.CanProve(this->l0_constraints.cand_factor[i] < f)) {
break; break;
} }
++i; ++i;
...@@ -180,8 +180,13 @@ void TileAxis::DumpAxis(bool on_screen) { ...@@ -180,8 +180,13 @@ void TileAxis::DumpAxis(bool on_screen) {
} }
if (!this->l1_constraints.cand_factor.empty()) { if (!this->l1_constraints.cand_factor.empty()) {
ss << "| L1 Cand_factors:{"; ss << "| L1 Cand_factors:{";
for (const auto &f : this->l1_constraints.cand_factor) { bool full_dump = this->l1_constraints.cand_factor.size() <= 10;
ss << f << ","; if (full_dump) {
for (const auto &f : this->l1_constraints.cand_factor) {
ss << f << ",";
}
} else {
ss << this->l1_constraints.cand_factor[0] << " ... " << this->l1_constraints.cand_factor.back();
} }
ss << "} |"; ss << "} |";
if (on_screen) LOG(INFO) << ss.str(); if (on_screen) LOG(INFO) << ss.str();
...@@ -189,8 +194,13 @@ void TileAxis::DumpAxis(bool on_screen) { ...@@ -189,8 +194,13 @@ void TileAxis::DumpAxis(bool on_screen) {
} }
if (!this->l0_constraints.cand_factor.empty()) { if (!this->l0_constraints.cand_factor.empty()) {
ss << "| L0 Cand_factors:{"; ss << "| L0 Cand_factors:{";
for (const auto &f : this->l0_constraints.cand_factor) { bool full_dump = this->l0_constraints.cand_factor.size() <= 10;
ss << f << ","; if (full_dump) {
for (const auto &f : this->l0_constraints.cand_factor) {
ss << f << ",";
}
} else {
ss << this->l0_constraints.cand_factor[0] << " ... " << this->l0_constraints.cand_factor.back();
} }
ss << "} |"; ss << "} |";
if (on_screen) LOG(INFO) << ss.str(); if (on_screen) LOG(INFO) << ss.str();
...@@ -1310,12 +1320,16 @@ void TilingAnalyzer::AddTilingConstraints() { ...@@ -1310,12 +1320,16 @@ void TilingAnalyzer::AddTilingConstraints() {
actived_strategies.push_back(&pd_attr_strategy); actived_strategies.push_back(&pd_attr_strategy);
CastStrategy cast_strategy(this); CastStrategy cast_strategy(this);
ReduceStrategy reduce_strategy(this);
VectorizedStrategy vectorized_strategy(this); VectorizedStrategy vectorized_strategy(this);
TensorOfTensorStrategy tot_strategy(this); TensorOfTensorStrategy tot_strategy(this);
actived_strategies.push_back(&cast_strategy); actived_strategies.push_back(&cast_strategy);
actived_strategies.push_back(&reduce_strategy);
actived_strategies.push_back(&vectorized_strategy); actived_strategies.push_back(&vectorized_strategy);
if (!is_dynamic_) {
ReduceStrategy reduce_strategy(this);
DmaAlignStrategy dma_align_stratgey(this);
actived_strategies.push_back(&reduce_strategy);
actived_strategies.push_back(&dma_align_stratgey);
}
actived_strategies.push_back(&tot_strategy); actived_strategies.push_back(&tot_strategy);
ModStrategy mod_strategy(this); ModStrategy mod_strategy(this);
...@@ -1512,9 +1526,6 @@ void TilingAnalyzer::DumpBufferUsageTimeable() { ...@@ -1512,9 +1526,6 @@ void TilingAnalyzer::DumpBufferUsageTimeable() {
lived_buf_name.insert(it.first->name); lived_buf_name.insert(it.first->name);
ss << "Buffer " << it.first->name << " | Alloc time: " << alloc_time << " | Last use time : " << last_use_time ss << "Buffer " << it.first->name << " | Alloc time: " << alloc_time << " | Last use time : " << last_use_time
<< " | "; << " | ";
ss << "live buf: [";
for (const auto &name : lived_buf_name) ss << name << ", ";
ss << "]";
logger_.AppendLog(ANA_BUF_LIVE_EXTENT, ss); logger_.AppendLog(ANA_BUF_LIVE_EXTENT, ss);
} }
} }
......
...@@ -21,17 +21,17 @@ namespace akg { ...@@ -21,17 +21,17 @@ namespace akg {
namespace ir { namespace ir {
namespace poly { namespace poly {
void TilingSolver::CollectMemoryLimit() { void TilingSolver::CollectMemoryLimit() {
double percentage = ALLOCATION_PERCENTAGE; percentage_ = ALLOCATION_PERCENTAGE;
for (auto attr : analyzer_.RootAxis()->attrs) { for (auto attr : analyzer_.RootAxis()->attrs) {
if (attr.attr_key != "MEM_RATIO") continue; if (attr.attr_key != "MEM_RATIO") continue;
CHECK_NE(attr.attr_value, ""); CHECK_NE(attr.attr_value, "");
percentage = std::strtod(attr.attr_value.c_str(), nullptr); percentage_ = std::strtod(attr.attr_value.c_str(), nullptr);
break; break;
} }
DavinciInfo &d_info = DavinciInfo::GetInstance(); DavinciInfo &d_info = DavinciInfo::GetInstance();
for (auto i = 0; i < MEM_SCOPE_BULK; ++i) { for (auto i = 0; i < MEM_SCOPE_BULK; ++i) {
this->mem_limit_[i] = d_info.GetMemoryLimitInScope(i) * percentage; this->mem_limit_[i] = d_info.GetMemoryLimitInScope(i) * percentage_;
} }
} }
...@@ -431,9 +431,7 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem ...@@ -431,9 +431,7 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem
TileAxis::Constraint cons = level == LEVEL1 ? axis->l1_constraints : axis->l0_constraints; TileAxis::Constraint cons = level == LEVEL1 ? axis->l1_constraints : axis->l0_constraints;
if (!cons.cand_factor.empty()) { if (!cons.cand_factor.empty()) {
for (auto i = static_cast<int>(cons.cand_factor.size()) - 1; i >= 0; --i) { for (auto max_cand : cons.cand_factor) {
auto max_cand = cons.cand_factor[i];
if (max_cand.as<IntImm>() == nullptr) { if (max_cand.as<IntImm>() == nullptr) {
ss << "Static shape should have const candidate factor, while got " << max_cand; ss << "Static shape should have const candidate factor, while got " << max_cand;
analyzer_.logger_.LogFatalAndSaveLog(ss.str()); analyzer_.logger_.LogFatalAndSaveLog(ss.str());
...@@ -443,6 +441,12 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem ...@@ -443,6 +441,12 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem
final_factor = max_cand.as<IntImm>()->value; final_factor = max_cand.as<IntImm>()->value;
ss << "--> Candidate factor " << final_factor; ss << "--> Candidate factor " << final_factor;
break; break;
} else if (max_cand.as<IntImm>()->value * exceed_ratio_ * percentage_ < static_mem_constraint) {
final_factor = max_cand.as<IntImm>()->value;
exceed_ratio_ =
exceed_ratio_ * (static_cast<double>(final_factor) / static_cast<double>(static_mem_constraint));
ss << "--> Candidate factor " << final_factor << " (exceed ratio update to " << exceed_ratio_ << ")";
break;
} }
} }
} else { } else {
...@@ -542,11 +546,11 @@ void InequalitySolver::CalculateMemoryInBuffer(const TilingAnalyzer::BufferEntry ...@@ -542,11 +546,11 @@ void InequalitySolver::CalculateMemoryInBuffer(const TilingAnalyzer::BufferEntry
ExprSimplifier().CanProveWithPosParam(mem_info->live_size[buf->scope] < mem_info->max_live_size[buf->scope]); ExprSimplifier().CanProveWithPosParam(mem_info->live_size[buf->scope] < mem_info->max_live_size[buf->scope]);
if (current_is_larger) { if (current_is_larger) {
ss << "Can prove current live size" << mem_info->live_size[buf->scope] << " greater than maximal size " ss << "[Update max] current live size " << mem_info->live_size[buf->scope] << " greater than maximal size "
<< mem_info->max_live_size[buf->scope]; << mem_info->max_live_size[buf->scope];
mem_info->max_live_size[buf->scope] = mem_info->live_size[buf->scope]; mem_info->max_live_size[buf->scope] = mem_info->live_size[buf->scope];
} else if (!current_is_smaller) { } else if (!current_is_smaller) {
ss << "Can not compare current live size" << mem_info->live_size[buf->scope] << " with maximal size " ss << "[Unknown] Can not compare current live size" << mem_info->live_size[buf->scope] << " with maximal size "
<< mem_info->max_live_size[buf->scope]; << mem_info->max_live_size[buf->scope];
mem_info->max_live_size[buf->scope] = CanonicalSimplify(mem_info->max_live_size[buf->scope] + buf_shape); mem_info->max_live_size[buf->scope] = CanonicalSimplify(mem_info->max_live_size[buf->scope] + buf_shape);
} }
...@@ -620,6 +624,7 @@ void InequalitySolver::UpdateMemInfo() { ...@@ -620,6 +624,7 @@ void InequalitySolver::UpdateMemInfo() {
} }
void InequalitySolver::UpdateMemInfoWithBufReuse() { void InequalitySolver::UpdateMemInfoWithBufReuse() {
std::stringstream ss;
auto mem_info = tiling_mem_info_.get(); auto mem_info = tiling_mem_info_.get();
CHECK(mem_info); CHECK(mem_info);
...@@ -631,6 +636,7 @@ void InequalitySolver::UpdateMemInfoWithBufReuse() { ...@@ -631,6 +636,7 @@ void InequalitySolver::UpdateMemInfoWithBufReuse() {
continue; continue;
} }
if (mem_info->live_size[it.first->scope].defined() && mem_info->live_buf[it.first].defined()) { if (mem_info->live_size[it.first->scope].defined() && mem_info->live_buf[it.first].defined()) {
ss << "Release buffer " << it.first->name << " with size " << mem_info->live_buf[it.first];
mem_info->live_size[it.first->scope] -= mem_info->live_buf[it.first]; mem_info->live_size[it.first->scope] -= mem_info->live_buf[it.first];
} }
mem_info->live_buf.erase(it.first); mem_info->live_buf.erase(it.first);
......
...@@ -36,6 +36,8 @@ class TilingSolver { ...@@ -36,6 +36,8 @@ class TilingSolver {
TileCandidate cand_; TileCandidate cand_;
int64_t mem_limit_[MEM_SCOPE_BULK]{0}; int64_t mem_limit_[MEM_SCOPE_BULK]{0};
int tiling_band_{0}; int tiling_band_{0};
double percentage_ = 0.5;
double exceed_ratio_ = 1; // allow memory allocation to exceed memory_size * percentage, may disable double buffer
}; };
class InequalitySolver : TilingSolver { class InequalitySolver : TilingSolver {
......
...@@ -213,17 +213,7 @@ void CastStrategy::AddConstraint() { ...@@ -213,17 +213,7 @@ void CastStrategy::AddConstraint() {
void ReduceStrategy::AddConstraint() { void ReduceStrategy::AddConstraint() {
for (auto axis : analyzer_->GetAxesOfAttr("REDUCE_DST_LAST")) { for (auto axis : analyzer_->GetAxesOfAttr("REDUCE_DST_LAST")) {
int64_t block_size = GetMaxAlignBytes(axis->data_size); axis->l1_constraints.tile_min_ = CastInt64ToExpr(GetMaxAlignBytes(axis->data_size));
int64_t const_extent = axis->GetConstExtent();
if (const_extent == -1) {
continue;
}
int64_t align_elem = ktvm::ir::gcd(block_size, const_extent);
if (align_elem == block_size) {
axis->l1_constraints.tile_min_ = align_elem;
} else {
axis->forbid_iso = true;
}
} }
} }
...@@ -250,6 +240,35 @@ void VectorizedStrategy::AddConstraint() { ...@@ -250,6 +240,35 @@ void VectorizedStrategy::AddConstraint() {
} }
} }
// Adds DMA alignment constraints to the L1 tiling of every axis that carries an
// attribute whose key contains both "ALIGN" and "DMA".
//
// With align_size = GetMaxAlignBytes(axis->data_size):
//   - when the axis has no constant extent (GetConstExtent() returns -1) or already
//     owns L1 candidate factors, only l1_constraints.tile_min_ is set to align_size;
//   - otherwise l1_constraints.cand_factor is replaced by every tile size in
//     [align_size, extent], largest first, whose tail (extent % size) is either
//     zero or at least align_size.
void DmaAlignStrategy::AddConstraint() {
  for (auto axis : analyzer_->GetAxesContainsAttr("ALIGN")) {
    for (const auto &attr : axis->attrs) {
      const bool has_align = attr.attr_key.find("ALIGN") != std::string::npos;
      const bool has_dma = attr.attr_key.find("DMA") != std::string::npos;
      if (!has_align || !has_dma) {
        continue;
      }
      auto align_size = GetMaxAlignBytes(axis->data_size);
      // NOTE(review): GetConstExtent's return type is not visible here; if it is wider
      // than int this assignment narrows -- confirm before relying on huge extents.
      int const_extent = axis->GetConstExtent();
      // For a dynamic shape, or an axis that already has other candidates, simply add
      // a tile-min constraint.
      if (const_extent == -1 || !axis->l1_constraints.cand_factor.empty()) {
        axis->l1_constraints.tile_min_ = CastInt64ToExpr(align_size);
        continue;
      }
      // For a static shape without other candidates, enumerate the aligned candidates.
      // The lower bound is clamped to 1 so that a non-positive align_size can never
      // drive `cand` down to 0, where `const_extent % cand` would be undefined.
      const auto min_cand = align_size > 0 ? align_size : 1;
      std::vector<ktvm::Expr> candidates;
      for (auto cand = const_extent; cand >= min_cand; --cand) {
        auto tail = const_extent % cand;
        if (tail == 0 || tail >= align_size) {
          candidates.emplace_back(CastIntToExpr(cand));
        }
      }
      axis->l1_constraints.cand_factor = candidates;
    }
  }
}
void TensorOfTensorStrategy::AddConstraint() { void TensorOfTensorStrategy::AddConstraint() {
for (auto axis : analyzer_->GetAxesOfAttr("TOT")) { for (auto axis : analyzer_->GetAxesOfAttr("TOT")) {
if (!axis->HasAttr("ALIGN:DMA")) continue; if (!axis->HasAttr("ALIGN:DMA")) continue;
......
...@@ -112,13 +112,20 @@ class VectorizedStrategy : public TilingStrategy { ...@@ -112,13 +112,20 @@ class VectorizedStrategy : public TilingStrategy {
void AddConstraint(); void AddConstraint();
}; };
// Tiling strategy whose AddConstraint() is declared here and defined out of line.
// It is constructed from the analyzer pointer, which is forwarded to TilingStrategy.
class DmaAlignStrategy : public TilingStrategy {
 public:
  explicit DmaAlignStrategy(const TilingAnalyzer *a) : TilingStrategy(a) {}
  ~DmaAlignStrategy() = default;

  // Attribute key this strategy is interested in.
  std::string interested_attr_key{"ALIGN"};

  void AddConstraint();
};
class TensorOfTensorStrategy : public TilingStrategy { class TensorOfTensorStrategy : public TilingStrategy {
public: public:
explicit TensorOfTensorStrategy(const TilingAnalyzer *a) : TilingStrategy(a) {} explicit TensorOfTensorStrategy(const TilingAnalyzer *a) : TilingStrategy(a) {}
~TensorOfTensorStrategy() {} ~TensorOfTensorStrategy() {}
void AddConstraint(); void AddConstraint();
std::string interested_attr_key = "CAST";
}; };
class PassDownAttrStrategy : public TilingStrategy { class PassDownAttrStrategy : public TilingStrategy {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册