提交 c0954c77 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!48 add dma align strategy in auto tiling that enables multicore

Merge pull request !48 from yangsijia/add-dma-align-strategy
无相关合并请求
......@@ -122,7 +122,7 @@ void TileAxis::InsertL1CandFactor(const Expr &f) {
while (i < this->l1_constraints.cand_factor.size()) {
if (Equal(this->l1_constraints.cand_factor[i], f)) {
return;
} else if (analyzer_->arith_ana_.CanProve(this->l1_constraints.cand_factor[i] > f)) {
} else if (analyzer_->arith_ana_.CanProve(this->l1_constraints.cand_factor[i] < f)) {
break;
}
++i;
......@@ -135,7 +135,7 @@ void TileAxis::InsertL0CandFactor(const Expr &f) {
while (i < this->l0_constraints.cand_factor.size()) {
if (Equal(this->l0_constraints.cand_factor[i], f)) {
return;
} else if (analyzer_->arith_ana_.CanProve(this->l0_constraints.cand_factor[i] > f)) {
} else if (analyzer_->arith_ana_.CanProve(this->l0_constraints.cand_factor[i] < f)) {
break;
}
++i;
......@@ -180,8 +180,13 @@ void TileAxis::DumpAxis(bool on_screen) {
}
if (!this->l1_constraints.cand_factor.empty()) {
ss << "| L1 Cand_factors:{";
for (const auto &f : this->l1_constraints.cand_factor) {
ss << f << ",";
bool full_dump = this->l1_constraints.cand_factor.size() <= 10;
if (full_dump) {
for (const auto &f : this->l1_constraints.cand_factor) {
ss << f << ",";
}
} else {
ss << this->l1_constraints.cand_factor[0] << " ... " << this->l1_constraints.cand_factor.back();
}
ss << "} |";
if (on_screen) LOG(INFO) << ss.str();
......@@ -189,8 +194,13 @@ void TileAxis::DumpAxis(bool on_screen) {
}
if (!this->l0_constraints.cand_factor.empty()) {
ss << "| L0 Cand_factors:{";
for (const auto &f : this->l0_constraints.cand_factor) {
ss << f << ",";
bool full_dump = this->l0_constraints.cand_factor.size() <= 10;
if (full_dump) {
for (const auto &f : this->l0_constraints.cand_factor) {
ss << f << ",";
}
} else {
ss << this->l0_constraints.cand_factor[0] << " ... " << this->l0_constraints.cand_factor.back();
}
ss << "} |";
if (on_screen) LOG(INFO) << ss.str();
......@@ -1310,12 +1320,16 @@ void TilingAnalyzer::AddTilingConstraints() {
actived_strategies.push_back(&pd_attr_strategy);
CastStrategy cast_strategy(this);
ReduceStrategy reduce_strategy(this);
VectorizedStrategy vectorized_strategy(this);
TensorOfTensorStrategy tot_strategy(this);
actived_strategies.push_back(&cast_strategy);
actived_strategies.push_back(&reduce_strategy);
actived_strategies.push_back(&vectorized_strategy);
if (!is_dynamic_) {
ReduceStrategy reduce_strategy(this);
DmaAlignStrategy dma_align_stratgey(this);
actived_strategies.push_back(&reduce_strategy);
actived_strategies.push_back(&dma_align_stratgey);
}
actived_strategies.push_back(&tot_strategy);
ModStrategy mod_strategy(this);
......@@ -1512,9 +1526,6 @@ void TilingAnalyzer::DumpBufferUsageTimeable() {
lived_buf_name.insert(it.first->name);
ss << "Buffer " << it.first->name << " | Alloc time: " << alloc_time << " | Last use time : " << last_use_time
<< " | ";
ss << "live buf: [";
for (const auto &name : lived_buf_name) ss << name << ", ";
ss << "]";
logger_.AppendLog(ANA_BUF_LIVE_EXTENT, ss);
}
}
......
......@@ -21,17 +21,17 @@ namespace akg {
namespace ir {
namespace poly {
void TilingSolver::CollectMemoryLimit() {
double percentage = ALLOCATION_PERCENTAGE;
percentage_ = ALLOCATION_PERCENTAGE;
for (auto attr : analyzer_.RootAxis()->attrs) {
if (attr.attr_key != "MEM_RATIO") continue;
CHECK_NE(attr.attr_value, "");
percentage = std::strtod(attr.attr_value.c_str(), nullptr);
percentage_ = std::strtod(attr.attr_value.c_str(), nullptr);
break;
}
DavinciInfo &d_info = DavinciInfo::GetInstance();
for (auto i = 0; i < MEM_SCOPE_BULK; ++i) {
this->mem_limit_[i] = d_info.GetMemoryLimitInScope(i) * percentage;
this->mem_limit_[i] = d_info.GetMemoryLimitInScope(i) * percentage_;
}
}
......@@ -431,9 +431,7 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem
TileAxis::Constraint cons = level == LEVEL1 ? axis->l1_constraints : axis->l0_constraints;
if (!cons.cand_factor.empty()) {
for (auto i = static_cast<int>(cons.cand_factor.size()) - 1; i >= 0; --i) {
auto max_cand = cons.cand_factor[i];
for (auto max_cand : cons.cand_factor) {
if (max_cand.as<IntImm>() == nullptr) {
ss << "Static shape should have const candidate factor, while got " << max_cand;
analyzer_.logger_.LogFatalAndSaveLog(ss.str());
......@@ -443,6 +441,12 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem
final_factor = max_cand.as<IntImm>()->value;
ss << "--> Candidate factor " << final_factor;
break;
} else if (max_cand.as<IntImm>()->value * exceed_ratio_ * percentage_ < static_mem_constraint) {
final_factor = max_cand.as<IntImm>()->value;
exceed_ratio_ =
exceed_ratio_ * (static_cast<double>(final_factor) / static_cast<double>(static_mem_constraint));
ss << "--> Candidate factor " << final_factor << " (exceed ratio update to " << exceed_ratio_ << ")";
break;
}
}
} else {
......@@ -542,11 +546,11 @@ void InequalitySolver::CalculateMemoryInBuffer(const TilingAnalyzer::BufferEntry
ExprSimplifier().CanProveWithPosParam(mem_info->live_size[buf->scope] < mem_info->max_live_size[buf->scope]);
if (current_is_larger) {
ss << "Can prove current live size" << mem_info->live_size[buf->scope] << " greater than maximal size "
ss << "[Update max] current live size " << mem_info->live_size[buf->scope] << " greater than maximal size "
<< mem_info->max_live_size[buf->scope];
mem_info->max_live_size[buf->scope] = mem_info->live_size[buf->scope];
} else if (!current_is_smaller) {
ss << "Can not compare current live size" << mem_info->live_size[buf->scope] << " with maximal size "
ss << "[Unknown] Can not compare current live size" << mem_info->live_size[buf->scope] << " with maximal size "
<< mem_info->max_live_size[buf->scope];
mem_info->max_live_size[buf->scope] = CanonicalSimplify(mem_info->max_live_size[buf->scope] + buf_shape);
}
......@@ -620,6 +624,7 @@ void InequalitySolver::UpdateMemInfo() {
}
void InequalitySolver::UpdateMemInfoWithBufReuse() {
std::stringstream ss;
auto mem_info = tiling_mem_info_.get();
CHECK(mem_info);
......@@ -631,6 +636,7 @@ void InequalitySolver::UpdateMemInfoWithBufReuse() {
continue;
}
if (mem_info->live_size[it.first->scope].defined() && mem_info->live_buf[it.first].defined()) {
ss << "Release buffer " << it.first->name << " with size " << mem_info->live_buf[it.first];
mem_info->live_size[it.first->scope] -= mem_info->live_buf[it.first];
}
mem_info->live_buf.erase(it.first);
......
......@@ -36,6 +36,8 @@ class TilingSolver {
TileCandidate cand_;
int64_t mem_limit_[MEM_SCOPE_BULK]{0};
int tiling_band_{0};
double percentage_ = 0.5;
double exceed_ratio_ = 1; // allow memory allocation to exceed memory_size * percentage, may disable double buffer
};
class InequalitySolver : TilingSolver {
......
......@@ -213,17 +213,7 @@ void CastStrategy::AddConstraint() {
void ReduceStrategy::AddConstraint() {
for (auto axis : analyzer_->GetAxesOfAttr("REDUCE_DST_LAST")) {
int64_t block_size = GetMaxAlignBytes(axis->data_size);
int64_t const_extent = axis->GetConstExtent();
if (const_extent == -1) {
continue;
}
int64_t align_elem = ktvm::ir::gcd(block_size, const_extent);
if (align_elem == block_size) {
axis->l1_constraints.tile_min_ = align_elem;
} else {
axis->forbid_iso = true;
}
axis->l1_constraints.tile_min_ = CastInt64ToExpr(GetMaxAlignBytes(axis->data_size));
}
}
......@@ -250,6 +240,35 @@ void VectorizedStrategy::AddConstraint() {
}
}
void DmaAlignStrategy::AddConstraint() {
for (auto axis : analyzer_->GetAxesContainsAttr("ALIGN")) {
for (const auto &attr : axis->attrs) {
LOG(INFO) << attr.attr_key;
if ((attr.attr_key.find("ALIGN") == std::string::npos) || (attr.attr_key.find("DMA") == std::string::npos)) {
continue;
}
auto align_size = GetMaxAlignBytes(axis->data_size);
int const_extent = axis->GetConstExtent();
// For dynamic shape or axes that has other candidates, simply add tile min constraint;
// for static shape that has no other candidate, add aligned candidates.
if (const_extent == -1 || !axis->l1_constraints.cand_factor.empty()) {
axis->l1_constraints.tile_min_ = CastInt64ToExpr(align_size);
} else {
std::vector<ktvm::Expr> candidates;
for (auto cand = const_extent; cand >= align_size; --cand) {
auto tail = const_extent % cand;
if (tail == 0 || tail >= align_size) {
candidates.emplace_back(CastIntToExpr(cand));
}
}
axis->l1_constraints.cand_factor = candidates;
}
}
}
}
void TensorOfTensorStrategy::AddConstraint() {
for (auto axis : analyzer_->GetAxesOfAttr("TOT")) {
if (!axis->HasAttr("ALIGN:DMA")) continue;
......
......@@ -112,13 +112,20 @@ class VectorizedStrategy : public TilingStrategy {
void AddConstraint();
};
class DmaAlignStrategy : public TilingStrategy {
public:
explicit DmaAlignStrategy(const TilingAnalyzer *a) : TilingStrategy(a) {}
~DmaAlignStrategy() {}
void AddConstraint();
std::string interested_attr_key = "ALIGN";
};
class TensorOfTensorStrategy : public TilingStrategy {
public:
explicit TensorOfTensorStrategy(const TilingAnalyzer *a) : TilingStrategy(a) {}
~TensorOfTensorStrategy() {}
void AddConstraint();
std::string interested_attr_key = "CAST";
};
class PassDownAttrStrategy : public TilingStrategy {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部