提交 456d200d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!47 enhance vectorization analysis in auto tiling

Merge pull request !47 from yangsijia/enhance-vectorization-analysis
...@@ -311,24 +311,45 @@ void SpaceAnalyzer::IdentifyInsnType() { ...@@ -311,24 +311,45 @@ void SpaceAnalyzer::IdentifyInsnType() {
void SpaceAnalyzer::IdentifyVectorizedAxes() { void SpaceAnalyzer::IdentifyVectorizedAxes() {
if (provides_ana_.empty()) return; if (provides_ana_.empty()) return;
std::string attr_key = "VECTORIZED"; std::string attr_key = "VECTORIZED";
std::unordered_map<std::string, const For *> last_axes; std::unordered_set<std::string> unsupported_insn = {"REDUCE", "TRANSFORM", "TRANSPOSE"};
std::unordered_map<std::string, const For *> mark_dst_axes;
for (auto it : provides_ana_) { for (auto it : provides_ana_) {
std::vector<ProvideEntry> pes = it.second; std::vector<ProvideEntry> pes = it.second;
for (auto pe : pes) { for (auto pe : pes) {
if (pe.src.size() != 1U) continue; bool skip = false;
for (auto insn : unsupported_insn) {
if (pe.basic_op_type.find(insn) != std::string::npos) {
skip = true;
break;
}
}
if (skip) {
continue;
}
Tensor dst_tensor = pe.dst; Tensor dst_tensor = pe.dst;
Tensor src_tensor = pe.src[0]; const For *dst_last = GetBufferInnerAxis(dst_tensor);
if (dst_tensor.var_names.size() != 1U || src_tensor.var_names.size() != 1U) continue; // skip if dst is scalar
const For *dst_last = GetBufferLastAxis(dst_tensor); if (dst_last == nullptr) {
const For *src_last = GetBufferLastAxis(src_tensor); continue;
if (dst_last != nullptr && src_last != nullptr && dst_last == src_last) { }
last_axes[dst_tensor.name] = dst_last;
const For *src_last = nullptr;
for (auto src : pe.src) {
const auto *last = GetBufferInnerAxis(src);
if (last != nullptr && last == dst_last) {
src_last = last;
break;
}
}
// skip if src tensor does not share same inner-most axis with dst tensor
if (src_last == nullptr && !pe.src.empty()) {
continue;
} }
mark_dst_axes[dst_tensor.name] = dst_last;
} }
} }
for (auto la : last_axes) { for (auto la : mark_dst_axes) {
TileAxis *last_axis = analyzer_->Axis(la.second); TileAxis *last_axis = analyzer_->Axis(la.second);
if (last_axis != nullptr) last_axis->MarkWithAttr(AttrInfo{attr_key, la.first}); if (last_axis != nullptr) last_axis->MarkWithAttr(AttrInfo{attr_key, la.first});
} }
...@@ -354,7 +375,7 @@ void SpaceAnalyzer::IdentifyDmaUnderCondition() { ...@@ -354,7 +375,7 @@ void SpaceAnalyzer::IdentifyDmaUnderCondition() {
}; };
ktvm::ir::PostOrderVisit(pe.cond->condition, DetectToT); ktvm::ir::PostOrderVisit(pe.cond->condition, DetectToT);
if (!contain_tot) continue; if (!contain_tot) continue;
TileAxis *tot_axis = analyzer_->Axis(GetBufferLastAxis(pe.dst)); TileAxis *tot_axis = analyzer_->Axis(GetBufferInnerAxis(pe.dst));
if (tot_axis != nullptr) tot_axis->MarkWithAttr(AttrInfo{attr_key, ""}); if (tot_axis != nullptr) tot_axis->MarkWithAttr(AttrInfo{attr_key, ""});
} }
} }
...@@ -371,14 +392,14 @@ void SpaceAnalyzer::IdentifyAlignAxes() { ...@@ -371,14 +392,14 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
std::vector<Tensor> src_tensors = pe.src; std::vector<Tensor> src_tensors = pe.src;
Tensor dst_tensor = pe.dst; Tensor dst_tensor = pe.dst;
if (pe.basic_op_type.find("TRANSPOSE") != std::string::npos) { if (pe.basic_op_type.find("TRANSPOSE") != std::string::npos) {
const For *dst_last = GetBufferLastAxis(dst_tensor); const For *dst_last = GetBufferInnerAxis(dst_tensor);
if (dst_last != nullptr) { if (dst_last != nullptr) {
align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type); align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type);
} else { } else {
analyzer_->RootAxis()->MarkWithAttr(AttrInfo{"TRANSFORM", dst_tensor.name}); analyzer_->RootAxis()->MarkWithAttr(AttrInfo{"TRANSFORM", dst_tensor.name});
} }
} else if (pe.basic_op_type.find("DMA") != std::string::npos) { } else if (pe.basic_op_type.find("DMA") != std::string::npos) {
const For *dst_last = GetBufferLastAxis(dst_tensor); const For *dst_last = GetBufferInnerAxis(dst_tensor);
if (dst_last != nullptr) { if (dst_last != nullptr) {
align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type); align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type);
} else { } else {
...@@ -394,7 +415,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { ...@@ -394,7 +415,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
} }
for (auto t : src_tensors) { for (auto t : src_tensors) {
if (t.loops.size() <= dst_tensor.loops.size()) continue; if (t.loops.size() <= dst_tensor.loops.size()) continue;
const For *src_last = GetBufferLastAxis(t); const For *src_last = GetBufferInnerAxis(t);
if (src_last != nullptr) { if (src_last != nullptr) {
align_axes_attrs[src_last] = std::make_pair(t.name, pe.basic_op_type); align_axes_attrs[src_last] = std::make_pair(t.name, pe.basic_op_type);
} }
...@@ -407,7 +428,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { ...@@ -407,7 +428,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
const Tensor dst_tensor) { const Tensor dst_tensor) {
for (auto src : src_tensors) { for (auto src : src_tensors) {
if (src.name != dst_tensor.name) { if (src.name != dst_tensor.name) {
src_last = GetBufferLastAxis(src); src_last = GetBufferInnerAxis(src);
src_name = src.name; src_name = src.name;
break; break;
} }
...@@ -416,7 +437,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { ...@@ -416,7 +437,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
if (const auto i = src_last->extent.as<IntImm>()) gm_block = i->value; if (const auto i = src_last->extent.as<IntImm>()) gm_block = i->value;
}; };
if (pe.basic_op_type.find("REDUCE") != std::string::npos) { if (pe.basic_op_type.find("REDUCE") != std::string::npos) {
const For *dst_last = GetBufferLastAxis(dst_tensor); const For *dst_last = GetBufferInnerAxis(dst_tensor);
int64_t ub_block = 1; int64_t ub_block = 1;
if (dst_last != nullptr) { if (dst_last != nullptr) {
align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type); align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type);
...@@ -430,7 +451,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { ...@@ -430,7 +451,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
} }
} }
} else if (pe.basic_op_type.find("BROADCAST") != std::string::npos) { } else if (pe.basic_op_type.find("BROADCAST") != std::string::npos) {
const For *dst_last = GetBufferLastAxis(dst_tensor); const For *dst_last = GetBufferInnerAxis(dst_tensor);
int64_t ub_block = 1; int64_t ub_block = 1;
if (dst_last == nullptr) continue; if (dst_last == nullptr) continue;
if (const auto i = dst_last->extent.as<IntImm>()) ub_block = i->value; if (const auto i = dst_last->extent.as<IntImm>()) ub_block = i->value;
...@@ -451,7 +472,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { ...@@ -451,7 +472,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
} }
} }
const For *SpaceAnalyzer::GetBufferLastAxis(Tensor t, int offset) { const For *SpaceAnalyzer::GetBufferInnerAxis(Tensor t, int offset) {
int last_dim = static_cast<int>(t.var_names.size()) - offset; int last_dim = static_cast<int>(t.var_names.size()) - offset;
auto it = t.loops.find(last_dim); auto it = t.loops.find(last_dim);
if (it != t.loops.end() && it->second.size() == 1U) return it->second[0]; if (it != t.loops.end() && it->second.size() == 1U) return it->second[0];
...@@ -470,11 +491,11 @@ void SpaceAnalyzer::IdentifyReduceAxes() { ...@@ -470,11 +491,11 @@ void SpaceAnalyzer::IdentifyReduceAxes() {
std::vector<ProvideEntry> pes = it.second; std::vector<ProvideEntry> pes = it.second;
for (auto pe : pes) { for (auto pe : pes) {
if ((pe.basic_op_type.find("REDUCE") == std::string::npos)) continue; if ((pe.basic_op_type.find("REDUCE") == std::string::npos)) continue;
const For *dst_last = GetBufferLastAxis(pe.dst); const For *dst_last = GetBufferInnerAxis(pe.dst);
if (dst_last == nullptr) { if (dst_last == nullptr) {
// Reduce op like A[i, 0] = A[i, 0] op B[i, j], we need to mark axis `i` as dst last for dma align. // Reduce op like A[i, 0] = A[i, 0] op B[i, j], we need to mark axis `i` as dst last for dma align.
for (auto offset = 0; offset < static_cast<int>(pe.dst.var_names.size()); ++offset) { for (auto offset = 0; offset < static_cast<int>(pe.dst.var_names.size()); ++offset) {
dst_last = GetBufferLastAxis(pe.dst, offset + 1); dst_last = GetBufferInnerAxis(pe.dst, offset + 1);
if (dst_last == nullptr) continue; if (dst_last == nullptr) continue;
MarkAttr(dst_last, "REDUCE_DST_LAST", pe.dst.name); MarkAttr(dst_last, "REDUCE_DST_LAST", pe.dst.name);
break; break;
...@@ -484,7 +505,7 @@ void SpaceAnalyzer::IdentifyReduceAxes() { ...@@ -484,7 +505,7 @@ void SpaceAnalyzer::IdentifyReduceAxes() {
} }
for (Tensor src : pe.src) { for (Tensor src : pe.src) {
if (src.name == pe.dst.name) continue; if (src.name == pe.dst.name) continue;
const For *src_last = GetBufferLastAxis(src); const For *src_last = GetBufferInnerAxis(src);
MarkAttr(src_last, "REDUCE_SRC_LAST", src.name); MarkAttr(src_last, "REDUCE_SRC_LAST", src.name);
std::string flow = src.name + "->" + pe.dst.name; std::string flow = src.name + "->" + pe.dst.name;
root->MarkWithAttr(AttrInfo{"REDUCE_FLOW", flow}); root->MarkWithAttr(AttrInfo{"REDUCE_FLOW", flow});
......
...@@ -67,7 +67,7 @@ class SpaceAnalyzer { ...@@ -67,7 +67,7 @@ class SpaceAnalyzer {
// Provides stmt after analysis. // Provides stmt after analysis.
std::unordered_map<const For *, std::vector<ProvideEntry>> provides_ana_; std::unordered_map<const For *, std::vector<ProvideEntry>> provides_ana_;
const For *GetBufferLastAxis(Tensor t, int offset = 1); const For *GetBufferInnerAxis(Tensor t, int offset = 1);
// generalized cases // generalized cases
void IdentifyInsnType(); void IdentifyInsnType();
void IdentifyDmaUnderCondition(); void IdentifyDmaUnderCondition();
......
...@@ -462,7 +462,9 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem ...@@ -462,7 +462,9 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem
ss << "--> Init factor " << final_factor; ss << "--> Init factor " << final_factor;
auto mod_value = cons.tile_mod_.as<IntImm>() ? cons.tile_mod_.as<IntImm>()->value : 1; auto mod_value = cons.tile_mod_.as<IntImm>() ? cons.tile_mod_.as<IntImm>()->value : 1;
if (static_shape >= mod_value && final_factor % mod_value != 0) { bool is_unaligned = (static_shape >= mod_value && final_factor % mod_value != 0);
bool need_to_align = (final_factor > mod_value || !axis->HasAttr("VECTORIZED"));
if (is_unaligned && need_to_align) {
final_factor = std::max(static_cast<int>(final_factor / mod_value * mod_value), 1); final_factor = std::max(static_cast<int>(final_factor / mod_value * mod_value), 1);
ss << "--> Mod value " << mod_value << " --> Align to mod " << final_factor; ss << "--> Mod value " << mod_value << " --> Align to mod " << final_factor;
} }
...@@ -810,12 +812,15 @@ bool TraverseSolver::IsTilable(TileInfo *info) { ...@@ -810,12 +812,15 @@ bool TraverseSolver::IsTilable(TileInfo *info) {
} }
if (level == LEVEL1) { if (level == LEVEL1) {
min_tile = cons.tile_mod_.as<IntImm>()->value; bool use_tile_min = (info->axis->forbid_iso && const_extent % cons.tile_mod_.as<IntImm>()->value != 0) ||
(cons.tile_min_.as<IntImm>()->value == MIN_TILE) || (axis->HasAttr("VECTORIZED")) ||
if ((info->axis->forbid_iso && const_extent % min_tile != 0) || (cons.tile_min_.as<IntImm>()->value > min_tile) || (cons.tile_min_.as<IntImm>()->value > cons.tile_mod_.as<IntImm>()->value);
(cons.tile_min_.as<IntImm>()->value == MIN_TILE)) { if (use_tile_min) {
min_tile = cons.tile_min_.as<IntImm>()->value; min_tile = cons.tile_min_.as<IntImm>()->value;
} else {
min_tile = cons.tile_mod_.as<IntImm>()->value;
} }
if (axis->range_min > min_tile) { if (axis->range_min > min_tile) {
min_tile = axis->range_min; min_tile = axis->range_min;
} }
......
...@@ -222,20 +222,20 @@ void VectorizedStrategy::AddConstraint() { ...@@ -222,20 +222,20 @@ void VectorizedStrategy::AddConstraint() {
return; return;
} }
for (auto axis : analyzer_->GetAxesOfAttr("VECTORIZED")) { for (auto axis : analyzer_->GetAxesOfAttr("VECTORIZED")) {
if (axis->HasAttr("DYNAMIC_BOUND")) { if (axis->HasAttr("DYNAMIC_BOUND") || axis->range_extent.as<IntImm>() == nullptr) {
continue; continue;
} }
int64_t min_byte = -1; int64_t min_byte = -1;
if (axis->data_size.empty()) { for (const auto &it : axis->data_size) {
min_byte = 1; if (it.second == 0) {
} else { continue;
for (const auto &it : axis->data_size) { }
if (min_byte == -1 || min_byte > it.second) { if (min_byte == -1 || min_byte > it.second) {
min_byte = it.second; min_byte = it.second;
}
} }
} }
CHECK_NE(min_byte, 0); min_byte = min_byte == -1 ? 1 : min_byte;
CHECK_GT(min_byte, 0);
axis->l1_constraints.tile_mod_ = CanonicalSimplify(CastIntToExpr(VECTORIZE_BYTE / min_byte)); axis->l1_constraints.tile_mod_ = CanonicalSimplify(CastIntToExpr(VECTORIZE_BYTE / min_byte));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册