diff --git a/src/poly/space_analyzer.cc b/src/poly/space_analyzer.cc index dedb5927c6c9a250795c96ec6fcf1bccd429bf47..6b8a4d1203ce7af8dff604b20170fa5b4d41ec33 100644 --- a/src/poly/space_analyzer.cc +++ b/src/poly/space_analyzer.cc @@ -311,24 +311,45 @@ void SpaceAnalyzer::IdentifyInsnType() { void SpaceAnalyzer::IdentifyVectorizedAxes() { if (provides_ana_.empty()) return; std::string attr_key = "VECTORIZED"; - std::unordered_map last_axes; - + std::unordered_set unsupported_insn = {"REDUCE", "TRANSFORM", "TRANSPOSE"}; + std::unordered_map mark_dst_axes; for (auto it : provides_ana_) { std::vector pes = it.second; for (auto pe : pes) { - if (pe.src.size() != 1U) continue; + bool skip = false; + for (auto insn : unsupported_insn) { + if (pe.basic_op_type.find(insn) != std::string::npos) { + skip = true; + break; + } + } + if (skip) { + continue; + } Tensor dst_tensor = pe.dst; - Tensor src_tensor = pe.src[0]; - if (dst_tensor.var_names.size() != 1U || src_tensor.var_names.size() != 1U) continue; - const For *dst_last = GetBufferLastAxis(dst_tensor); - const For *src_last = GetBufferLastAxis(src_tensor); - if (dst_last != nullptr && src_last != nullptr && dst_last == src_last) { - last_axes[dst_tensor.name] = dst_last; + const For *dst_last = GetBufferInnerAxis(dst_tensor); + // skip if dst is scalar + if (dst_last == nullptr) { + continue; + } + + const For *src_last = nullptr; + for (auto src : pe.src) { + const auto *last = GetBufferInnerAxis(src); + if (last != nullptr && last == dst_last) { + src_last = last; + break; + } + } + // skip if src tensor does not share same inner-most axis with dst tensor + if (src_last == nullptr && !pe.src.empty()) { + continue; } + mark_dst_axes[dst_tensor.name] = dst_last; } } - for (auto la : last_axes) { + for (auto la : mark_dst_axes) { TileAxis *last_axis = analyzer_->Axis(la.second); if (last_axis != nullptr) last_axis->MarkWithAttr(AttrInfo{attr_key, la.first}); } @@ -354,7 +375,7 @@ void SpaceAnalyzer::IdentifyDmaUnderCondition() { }; ktvm::ir::PostOrderVisit(pe.cond->condition, DetectToT); if (!contain_tot) continue; - TileAxis *tot_axis = analyzer_->Axis(GetBufferLastAxis(pe.dst)); + TileAxis *tot_axis = analyzer_->Axis(GetBufferInnerAxis(pe.dst)); if (tot_axis != nullptr) tot_axis->MarkWithAttr(AttrInfo{attr_key, ""}); } } @@ -371,14 +392,14 @@ void SpaceAnalyzer::IdentifyAlignAxes() { std::vector src_tensors = pe.src; Tensor dst_tensor = pe.dst; if (pe.basic_op_type.find("TRANSPOSE") != std::string::npos) { - const For *dst_last = GetBufferLastAxis(dst_tensor); + const For *dst_last = GetBufferInnerAxis(dst_tensor); if (dst_last != nullptr) { align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type); } else { analyzer_->RootAxis()->MarkWithAttr(AttrInfo{"TRANSFORM", dst_tensor.name}); } } else if (pe.basic_op_type.find("DMA") != std::string::npos) { - const For *dst_last = GetBufferLastAxis(dst_tensor); + const For *dst_last = GetBufferInnerAxis(dst_tensor); if (dst_last != nullptr) { align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type); } else { @@ -394,7 +415,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { } for (auto t : src_tensors) { if (t.loops.size() <= dst_tensor.loops.size()) continue; - const For *src_last = GetBufferLastAxis(t); + const For *src_last = GetBufferInnerAxis(t); if (src_last != nullptr) { align_axes_attrs[src_last] = std::make_pair(t.name, pe.basic_op_type); } @@ -407,7 +428,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { const Tensor dst_tensor) { for (auto src : src_tensors) { if (src.name != dst_tensor.name) { - src_last = GetBufferLastAxis(src); + src_last = GetBufferInnerAxis(src); src_name = src.name; break; } @@ -416,7 +437,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { if (const auto i = src_last->extent.as()) gm_block = i->value; }; if (pe.basic_op_type.find("REDUCE") != std::string::npos) { - const For *dst_last = GetBufferLastAxis(dst_tensor); + const For *dst_last = GetBufferInnerAxis(dst_tensor); int64_t ub_block = 1; if (dst_last != nullptr) { align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type); @@ -430,7 +451,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { } } } else if (pe.basic_op_type.find("BROADCAST") != std::string::npos) { - const For *dst_last = GetBufferLastAxis(dst_tensor); + const For *dst_last = GetBufferInnerAxis(dst_tensor); int64_t ub_block = 1; if (dst_last == nullptr) continue; if (const auto i = dst_last->extent.as()) ub_block = i->value; @@ -451,7 +472,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() { } } -const For *SpaceAnalyzer::GetBufferLastAxis(Tensor t, int offset) { +const For *SpaceAnalyzer::GetBufferInnerAxis(Tensor t, int offset) { int last_dim = static_cast(t.var_names.size()) - offset; auto it = t.loops.find(last_dim); if (it != t.loops.end() && it->second.size() == 1U) return it->second[0]; @@ -470,11 +491,11 @@ void SpaceAnalyzer::IdentifyReduceAxes() { std::vector pes = it.second; for (auto pe : pes) { if ((pe.basic_op_type.find("REDUCE") == std::string::npos)) continue; - const For *dst_last = GetBufferLastAxis(pe.dst); + const For *dst_last = GetBufferInnerAxis(pe.dst); if (dst_last == nullptr) { // Reduce op like A[i, 0] = A[i, 0] op B[i, j], we need to mark axis `i` as dst last for dma align. for (auto offset = 0; offset < static_cast(pe.dst.var_names.size()); ++offset) { - dst_last = GetBufferLastAxis(pe.dst, offset + 1); + dst_last = GetBufferInnerAxis(pe.dst, offset + 1); if (dst_last == nullptr) continue; MarkAttr(dst_last, "REDUCE_DST_LAST", pe.dst.name); break; @@ -484,7 +505,7 @@ void SpaceAnalyzer::IdentifyReduceAxes() { } for (Tensor src : pe.src) { if (src.name == pe.dst.name) continue; - const For *src_last = GetBufferLastAxis(src); + const For *src_last = GetBufferInnerAxis(src); MarkAttr(src_last, "REDUCE_SRC_LAST", src.name); std::string flow = src.name + "->" + pe.dst.name; root->MarkWithAttr(AttrInfo{"REDUCE_FLOW", flow}); diff --git a/src/poly/space_analyzer.h b/src/poly/space_analyzer.h index d48a09daeee8a2b9678dcc3175107167902e5c21..471170fc583cb84d8b4de4feab679e99cfd9e41d 100644 --- a/src/poly/space_analyzer.h +++ b/src/poly/space_analyzer.h @@ -67,7 +67,7 @@ class SpaceAnalyzer { // Provides stmt after analysis. std::unordered_map> provides_ana_; - const For *GetBufferLastAxis(Tensor t, int offset = 1); + const For *GetBufferInnerAxis(Tensor t, int offset = 1); // generalized cases void IdentifyInsnType(); void IdentifyDmaUnderCondition(); diff --git a/src/poly/tiling_solver.cc b/src/poly/tiling_solver.cc index f3888aa319d1fd965210baf32ca331ba497f0fe3..c58cdf4e9259e99fe939708ebf6694207ebe380a 100644 --- a/src/poly/tiling_solver.cc +++ b/src/poly/tiling_solver.cc @@ -462,7 +462,9 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem ss << "--> Init factor " << final_factor; auto mod_value = cons.tile_mod_.as() ? cons.tile_mod_.as()->value : 1; - if (static_shape >= mod_value && final_factor % mod_value != 0) { + bool is_unaligned = (static_shape >= mod_value && final_factor % mod_value != 0); + bool need_to_align = (final_factor > mod_value || !axis->HasAttr("VECTORIZED")); + if (is_unaligned && need_to_align) { final_factor = std::max(static_cast(final_factor / mod_value * mod_value), 1); ss << "--> Mod value " << mod_value << " --> Align to mod " << final_factor; } @@ -810,12 +812,15 @@ bool TraverseSolver::IsTilable(TileInfo *info) { } if (level == LEVEL1) { - min_tile = cons.tile_mod_.as()->value; - - if ((info->axis->forbid_iso && const_extent % min_tile != 0) || (cons.tile_min_.as()->value > min_tile) || - (cons.tile_min_.as()->value == MIN_TILE)) { + bool use_tile_min = (info->axis->forbid_iso && const_extent % cons.tile_mod_.as()->value != 0) || + (cons.tile_min_.as()->value == MIN_TILE) || (axis->HasAttr("VECTORIZED")) || + (cons.tile_min_.as()->value > cons.tile_mod_.as()->value); + if (use_tile_min) { min_tile = cons.tile_min_.as()->value; + } else { + min_tile = cons.tile_mod_.as()->value; } + if (axis->range_min > min_tile) { min_tile = axis->range_min; } diff --git a/src/poly/tiling_strategy_manager.cc b/src/poly/tiling_strategy_manager.cc index 712a1799d8b18cbed9d8ab369cf934c45fda51ce..35877c6a997dae4ca2037226c025a9b457b4c7b7 100644 --- a/src/poly/tiling_strategy_manager.cc +++ b/src/poly/tiling_strategy_manager.cc @@ -222,20 +222,20 @@ void VectorizedStrategy::AddConstraint() { return; } for (auto axis : analyzer_->GetAxesOfAttr("VECTORIZED")) { - if (axis->HasAttr("DYNAMIC_BOUND")) { + if (axis->HasAttr("DYNAMIC_BOUND") || axis->range_extent.as() == nullptr) { continue; } int64_t min_byte = -1; - if (axis->data_size.empty()) { - min_byte = 1; - } else { - for (const auto &it : axis->data_size) { - if (min_byte == -1 || min_byte > it.second) { - min_byte = it.second; - } + for (const auto &it : axis->data_size) { + if (it.second == 0) { + continue; + } + if (min_byte == -1 || min_byte > it.second) { + min_byte = it.second; } } - CHECK_NE(min_byte, 0); + min_byte = min_byte == -1 ? 1 : min_byte; + CHECK_GT(min_byte, 0); axis->l1_constraints.tile_mod_ = CanonicalSimplify(CastIntToExpr(VECTORIZE_BYTE / min_byte)); } }