提交 01ee1998 编写于 作者: H hanhuifeng2020

fix a bug that random test case segment_max failed

上级 fd902f71
......@@ -848,6 +848,9 @@ BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_
new_strides.Set(i, new_strides[i + 1] * new_dims[i + 1]);
}
new_dims.Set(new_dims.size() - 1, simd_len);
} else if (simd_var->name_hint == un_def_var) {
new_dims.Set(var_idx, extent);
new_strides.Set(var_idx, last_dim_len);
} else {
new_dims = {extent};
}
......@@ -946,7 +949,7 @@ BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_
src_tmp_info1.GetNode()->shape_ = dst_tmp_info->shape_;
src_tmp_info1.GetNode()->strides_ = dst_tmp_info->strides_;
src_tmp_info1.GetNode()->var_ = dst_tmp_info->var_;
src_tmp_info1.GetNode()->index_ = dst_tmp_info->index_;
src_tmp_info1.GetNode()->index_ = 0;
}
}
......
......@@ -453,6 +453,25 @@ class RegCondition : public IRMutator {
int reg_cnt_{0};
};
class ForSimplify : public IRMutator {
public:
ForSimplify() = default;
~ForSimplify() override = default;
private:
Stmt Mutate_(const For *op, const Stmt &s) final {
if ((op->extent.as<IntImm>() && op->extent.as<IntImm>()->value == 1) ||
(op->extent.as<UIntImm>() && op->extent.as<UIntImm>()->value == 1)) {
Map<Var, Expr> var_map;
var_map.Set(op->loop_var, op->min);
auto body = Mutate(op->body);
body = Simplify(Substitute(body, var_map));
return body;
}
return IRMutator::Mutate_(op, s);
}
};
Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Map<Tensor, Buffer> &extern_buffer,
bool is_dynamic) {
char *debug_var = getenv("DEBUG_MODE");
......@@ -463,6 +482,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma
if (debug_mode) {
stmt = EmitInsnDebug(stmt);
}
stmt = ForSimplify().Mutate(stmt);
stmt = PreEmit().Mutate(stmt);
if (!is_dynamic) {
char *comment_var = getenv("COMMENT_LEVEL");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册