提交 b29b32d1 编写于 作者: D dabaiji

enhance vectorization analysis in auto tiling

上级 225139a0
......@@ -311,24 +311,45 @@ void SpaceAnalyzer::IdentifyInsnType() {
void SpaceAnalyzer::IdentifyVectorizedAxes() {
if (provides_ana_.empty()) return;
std::string attr_key = "VECTORIZED";
std::unordered_map<std::string, const For *> last_axes;
std::unordered_set<std::string> unsupported_insn = {"REDUCE", "TRANSFORM", "TRANSPOSE"};
std::unordered_map<std::string, const For *> mark_dst_axes;
for (auto it : provides_ana_) {
std::vector<ProvideEntry> pes = it.second;
for (auto pe : pes) {
if (pe.src.size() != 1U) continue;
bool skip = false;
for (auto insn : unsupported_insn) {
if (pe.basic_op_type.find(insn) != std::string::npos) {
skip = true;
break;
}
}
if (skip) {
continue;
}
Tensor dst_tensor = pe.dst;
Tensor src_tensor = pe.src[0];
if (dst_tensor.var_names.size() != 1U || src_tensor.var_names.size() != 1U) continue;
const For *dst_last = GetBufferLastAxis(dst_tensor);
const For *src_last = GetBufferLastAxis(src_tensor);
if (dst_last != nullptr && src_last != nullptr && dst_last == src_last) {
last_axes[dst_tensor.name] = dst_last;
const For *dst_last = GetBufferInnerAxis(dst_tensor);
// skip if dst is scalar
if (dst_last == nullptr) {
continue;
}
const For *src_last = nullptr;
for (auto src : pe.src) {
const auto *last = GetBufferInnerAxis(src);
if (last != nullptr && last == dst_last) {
src_last = last;
break;
}
}
// skip if none of the src tensors shares the same innermost axis as the dst tensor
if (src_last == nullptr && !pe.src.empty()) {
continue;
}
mark_dst_axes[dst_tensor.name] = dst_last;
}
}
for (auto la : last_axes) {
for (auto la : mark_dst_axes) {
TileAxis *last_axis = analyzer_->Axis(la.second);
if (last_axis != nullptr) last_axis->MarkWithAttr(AttrInfo{attr_key, la.first});
}
......@@ -354,7 +375,7 @@ void SpaceAnalyzer::IdentifyDmaUnderCondition() {
};
ktvm::ir::PostOrderVisit(pe.cond->condition, DetectToT);
if (!contain_tot) continue;
TileAxis *tot_axis = analyzer_->Axis(GetBufferLastAxis(pe.dst));
TileAxis *tot_axis = analyzer_->Axis(GetBufferInnerAxis(pe.dst));
if (tot_axis != nullptr) tot_axis->MarkWithAttr(AttrInfo{attr_key, ""});
}
}
......@@ -371,14 +392,14 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
std::vector<Tensor> src_tensors = pe.src;
Tensor dst_tensor = pe.dst;
if (pe.basic_op_type.find("TRANSPOSE") != std::string::npos) {
const For *dst_last = GetBufferLastAxis(dst_tensor);
const For *dst_last = GetBufferInnerAxis(dst_tensor);
if (dst_last != nullptr) {
align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type);
} else {
analyzer_->RootAxis()->MarkWithAttr(AttrInfo{"TRANSFORM", dst_tensor.name});
}
} else if (pe.basic_op_type.find("DMA") != std::string::npos) {
const For *dst_last = GetBufferLastAxis(dst_tensor);
const For *dst_last = GetBufferInnerAxis(dst_tensor);
if (dst_last != nullptr) {
align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type);
} else {
......@@ -394,7 +415,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
}
for (auto t : src_tensors) {
if (t.loops.size() <= dst_tensor.loops.size()) continue;
const For *src_last = GetBufferLastAxis(t);
const For *src_last = GetBufferInnerAxis(t);
if (src_last != nullptr) {
align_axes_attrs[src_last] = std::make_pair(t.name, pe.basic_op_type);
}
......@@ -407,7 +428,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
const Tensor dst_tensor) {
for (auto src : src_tensors) {
if (src.name != dst_tensor.name) {
src_last = GetBufferLastAxis(src);
src_last = GetBufferInnerAxis(src);
src_name = src.name;
break;
}
......@@ -416,7 +437,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
if (const auto i = src_last->extent.as<IntImm>()) gm_block = i->value;
};
if (pe.basic_op_type.find("REDUCE") != std::string::npos) {
const For *dst_last = GetBufferLastAxis(dst_tensor);
const For *dst_last = GetBufferInnerAxis(dst_tensor);
int64_t ub_block = 1;
if (dst_last != nullptr) {
align_axes_attrs[dst_last] = std::make_pair(dst_tensor.name, pe.basic_op_type);
......@@ -430,7 +451,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
}
}
} else if (pe.basic_op_type.find("BROADCAST") != std::string::npos) {
const For *dst_last = GetBufferLastAxis(dst_tensor);
const For *dst_last = GetBufferInnerAxis(dst_tensor);
int64_t ub_block = 1;
if (dst_last == nullptr) continue;
if (const auto i = dst_last->extent.as<IntImm>()) ub_block = i->value;
......@@ -451,7 +472,7 @@ void SpaceAnalyzer::IdentifyAlignAxes() {
}
}
const For *SpaceAnalyzer::GetBufferLastAxis(Tensor t, int offset) {
const For *SpaceAnalyzer::GetBufferInnerAxis(Tensor t, int offset) {
int last_dim = static_cast<int>(t.var_names.size()) - offset;
auto it = t.loops.find(last_dim);
if (it != t.loops.end() && it->second.size() == 1U) return it->second[0];
......@@ -470,11 +491,11 @@ void SpaceAnalyzer::IdentifyReduceAxes() {
std::vector<ProvideEntry> pes = it.second;
for (auto pe : pes) {
if ((pe.basic_op_type.find("REDUCE") == std::string::npos)) continue;
const For *dst_last = GetBufferLastAxis(pe.dst);
const For *dst_last = GetBufferInnerAxis(pe.dst);
if (dst_last == nullptr) {
// For a reduce op like A[i, 0] = A[i, 0] op B[i, j], mark axis `i` as the dst last axis for DMA alignment.
for (auto offset = 0; offset < static_cast<int>(pe.dst.var_names.size()); ++offset) {
dst_last = GetBufferLastAxis(pe.dst, offset + 1);
dst_last = GetBufferInnerAxis(pe.dst, offset + 1);
if (dst_last == nullptr) continue;
MarkAttr(dst_last, "REDUCE_DST_LAST", pe.dst.name);
break;
......@@ -484,7 +505,7 @@ void SpaceAnalyzer::IdentifyReduceAxes() {
}
for (Tensor src : pe.src) {
if (src.name == pe.dst.name) continue;
const For *src_last = GetBufferLastAxis(src);
const For *src_last = GetBufferInnerAxis(src);
MarkAttr(src_last, "REDUCE_SRC_LAST", src.name);
std::string flow = src.name + "->" + pe.dst.name;
root->MarkWithAttr(AttrInfo{"REDUCE_FLOW", flow});
......
......@@ -67,7 +67,7 @@ class SpaceAnalyzer {
// Provides stmt after analysis.
std::unordered_map<const For *, std::vector<ProvideEntry>> provides_ana_;
const For *GetBufferLastAxis(Tensor t, int offset = 1);
const For *GetBufferInnerAxis(Tensor t, int offset = 1);
// generalized cases
void IdentifyInsnType();
void IdentifyDmaUnderCondition();
......
......@@ -458,7 +458,9 @@ int64_t InequalitySolver::DetermineTileForStatic(TileAxis *axis, const Expr &mem
ss << "--> Init factor " << final_factor;
auto mod_value = cons.tile_mod_.as<IntImm>() ? cons.tile_mod_.as<IntImm>()->value : 1;
if (static_shape >= mod_value && final_factor % mod_value != 0) {
bool is_unaligned = (static_shape >= mod_value && final_factor % mod_value != 0);
bool need_to_align = (final_factor > mod_value || !axis->HasAttr("VECTORIZED"));
if (is_unaligned && need_to_align) {
final_factor = std::max(static_cast<int>(final_factor / mod_value * mod_value), 1);
ss << "--> Mod value " << mod_value << " --> Align to mod " << final_factor;
}
......@@ -804,12 +806,15 @@ bool TraverseSolver::IsTilable(TileInfo *info) {
}
if (level == LEVEL1) {
min_tile = cons.tile_mod_.as<IntImm>()->value;
if ((info->axis->forbid_iso && const_extent % min_tile != 0) || (cons.tile_min_.as<IntImm>()->value > min_tile) ||
(cons.tile_min_.as<IntImm>()->value == MIN_TILE)) {
bool use_tile_min = (info->axis->forbid_iso && const_extent % cons.tile_mod_.as<IntImm>()->value != 0) ||
(cons.tile_min_.as<IntImm>()->value == MIN_TILE) || (axis->HasAttr("VECTORIZED")) ||
(cons.tile_min_.as<IntImm>()->value > cons.tile_mod_.as<IntImm>()->value);
if (use_tile_min) {
min_tile = cons.tile_min_.as<IntImm>()->value;
} else {
min_tile = cons.tile_mod_.as<IntImm>()->value;
}
if (axis->range_min > min_tile) {
min_tile = axis->range_min;
}
......
......@@ -232,20 +232,20 @@ void VectorizedStrategy::AddConstraint() {
return;
}
for (auto axis : analyzer_->GetAxesOfAttr("VECTORIZED")) {
if (axis->HasAttr("DYNAMIC_BOUND")) {
if (axis->HasAttr("DYNAMIC_BOUND") || axis->range_extent.as<IntImm>() == nullptr) {
continue;
}
int64_t min_byte = -1;
if (axis->data_size.empty()) {
min_byte = 1;
} else {
for (const auto &it : axis->data_size) {
if (min_byte == -1 || min_byte > it.second) {
min_byte = it.second;
}
for (const auto &it : axis->data_size) {
if (it.second == 0) {
continue;
}
if (min_byte == -1 || min_byte > it.second) {
min_byte = it.second;
}
}
CHECK_NE(min_byte, 0);
min_byte = min_byte == -1 ? 1 : min_byte;
CHECK_GT(min_byte, 0);
axis->l1_constraints.tile_mod_ = CanonicalSimplify(CastIntToExpr(VECTORIZE_BYTE / min_byte));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册