提交 8d9c11b3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!127 fix a bug that random test case segment_max failed

Merge pull request !127 from hanhuifeng/for_1_remove
...@@ -848,6 +848,9 @@ BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_ ...@@ -848,6 +848,9 @@ BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_
new_strides.Set(i, new_strides[i + 1] * new_dims[i + 1]); new_strides.Set(i, new_strides[i + 1] * new_dims[i + 1]);
} }
new_dims.Set(new_dims.size() - 1, simd_len); new_dims.Set(new_dims.size() - 1, simd_len);
} else if (simd_var->name_hint == un_def_var) {
new_dims.Set(var_idx, extent);
new_strides.Set(var_idx, last_dim_len);
} else { } else {
new_dims = {extent}; new_dims = {extent};
} }
...@@ -946,7 +949,7 @@ BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_ ...@@ -946,7 +949,7 @@ BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_
src_tmp_info1.GetNode()->shape_ = dst_tmp_info->shape_; src_tmp_info1.GetNode()->shape_ = dst_tmp_info->shape_;
src_tmp_info1.GetNode()->strides_ = dst_tmp_info->strides_; src_tmp_info1.GetNode()->strides_ = dst_tmp_info->strides_;
src_tmp_info1.GetNode()->var_ = dst_tmp_info->var_; src_tmp_info1.GetNode()->var_ = dst_tmp_info->var_;
src_tmp_info1.GetNode()->index_ = dst_tmp_info->index_; src_tmp_info1.GetNode()->index_ = 0;
} }
} }
......
...@@ -453,6 +453,25 @@ class RegCondition : public IRMutator { ...@@ -453,6 +453,25 @@ class RegCondition : public IRMutator {
int reg_cnt_{0}; int reg_cnt_{0};
}; };
class ForSimplify : public IRMutator {
public:
ForSimplify() = default;
~ForSimplify() override = default;
private:
Stmt Mutate_(const For *op, const Stmt &s) final {
if ((op->extent.as<IntImm>() && op->extent.as<IntImm>()->value == 1) ||
(op->extent.as<UIntImm>() && op->extent.as<UIntImm>()->value == 1)) {
Map<Var, Expr> var_map;
var_map.Set(op->loop_var, op->min);
auto body = Mutate(op->body);
body = Simplify(Substitute(body, var_map));
return body;
}
return IRMutator::Mutate_(op, s);
}
};
Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Map<Tensor, Buffer> &extern_buffer, Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Map<Tensor, Buffer> &extern_buffer,
bool is_dynamic) { bool is_dynamic) {
char *debug_var = getenv("DEBUG_MODE"); char *debug_var = getenv("DEBUG_MODE");
...@@ -463,6 +482,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma ...@@ -463,6 +482,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma
if (debug_mode) { if (debug_mode) {
stmt = EmitInsnDebug(stmt); stmt = EmitInsnDebug(stmt);
} }
stmt = ForSimplify().Mutate(stmt);
stmt = PreEmit().Mutate(stmt); stmt = PreEmit().Mutate(stmt);
if (!is_dynamic) { if (!is_dynamic) {
char *comment_var = getenv("COMMENT_LEVEL"); char *comment_var = getenv("COMMENT_LEVEL");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册