diff --git a/src/emit_insn/insn_args_calculator.cc b/src/emit_insn/insn_args_calculator.cc index 2e6a7144cd5a8c20b59b7b37826bbd6827e42581..a05af0ccc2af99fe6457d702a109e1beebac9b77 100644 --- a/src/emit_insn/insn_args_calculator.cc +++ b/src/emit_insn/insn_args_calculator.cc @@ -848,6 +848,9 @@ BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_ new_strides.Set(i, new_strides[i + 1] * new_dims[i + 1]); } new_dims.Set(new_dims.size() - 1, simd_len); + } else if (simd_var->name_hint == un_def_var) { + new_dims.Set(var_idx, extent); + new_strides.Set(var_idx, last_dim_len); } else { new_dims = {extent}; } @@ -946,7 +949,7 @@ BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_ src_tmp_info1.GetNode()->shape_ = dst_tmp_info->shape_; src_tmp_info1.GetNode()->strides_ = dst_tmp_info->strides_; src_tmp_info1.GetNode()->var_ = dst_tmp_info->var_; - src_tmp_info1.GetNode()->index_ = dst_tmp_info->index_; + src_tmp_info1.GetNode()->index_ = 0; } } diff --git a/src/pass/emit_insn.cc b/src/pass/emit_insn.cc index 38b90a85e12739fd2f459deca7c43225b63119b5..c8d7df0559d5c15c5d9d5f09f825ee6a68743d6e 100644 --- a/src/pass/emit_insn.cc +++ b/src/pass/emit_insn.cc @@ -453,6 +453,25 @@ class RegCondition : public IRMutator { int reg_cnt_{0}; }; +class ForSimplify : public IRMutator { + public: + ForSimplify() = default; + ~ForSimplify() override = default; + + private: + Stmt Mutate_(const For *op, const Stmt &s) final { + if ((op->extent.as() && op->extent.as()->value == 1) || + (op->extent.as() && op->extent.as()->value == 1)) { + Map var_map; + var_map.Set(op->loop_var, op->min); + auto body = Mutate(op->body); + body = Simplify(Substitute(body, var_map)); + return body; + } + return IRMutator::Mutate_(op, s); + } +}; + Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Map &extern_buffer, bool is_dynamic) { char *debug_var = getenv("DEBUG_MODE"); @@ -463,6 +482,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma if (debug_mode) { stmt = EmitInsnDebug(stmt); } + stmt = ForSimplify().Mutate(stmt); stmt = PreEmit().Mutate(stmt); if (!is_dynamic) { char *comment_var = getenv("COMMENT_LEVEL");