未验证 提交 c170074d 编写于 作者: 傅剑寒 提交者: GitHub

rename StoreButMutator and fix RampMutator (#56966)

上级 0900a790
......@@ -53,9 +53,9 @@ void PartialSimplify(
}
//! Simplify the expression but Load.
struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {
struct SimplifyNoPureMathMutator : public ir::IRMutator<ir::Expr*> {
common::cas_intervals_t& var_intervals;
explicit SimplifyButStoreLoadMutator(
explicit SimplifyNoPureMathMutator(
common::cas_intervals_t& var_intervals) // NOLINT
: var_intervals(var_intervals) {}
......@@ -76,19 +76,6 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {
__(Max)
#undef __
void Visit(const Ramp* op, Expr* expr) override {
auto* node = expr->As<Ramp>();
CHECK(common::IsPureMath(node->base));
CHECK(common::IsPureMath(node->stride));
PartialSimplify(&node->base, var_intervals);
PartialSimplify(&node->stride, var_intervals);
}
void Visit(const Cast* op, Expr* expr) override {
auto* node = expr->As<Cast>();
Visit(&node->v(), &node->v());
}
void Visit(const PolyFor* op, Expr* expr) override {
auto* node = expr->As<ir::PolyFor>();
node->condition = common::SolveInequality(op->condition, op->iterator);
......@@ -138,7 +125,7 @@ struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
if (common::IsPureMath(idx)) {
PartialSimplify(&idx, var_intervals_);
} else {
SimplifyButStoreLoadMutator mutator(var_intervals_);
SimplifyNoPureMathMutator mutator(var_intervals_);
mutator(&idx);
}
}
......@@ -176,7 +163,7 @@ struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
if (common::IsPureMath(idx)) {
PartialSimplify(&idx, var_intervals_);
} else {
SimplifyButStoreLoadMutator mutator(var_intervals_);
SimplifyNoPureMathMutator mutator(var_intervals_);
mutator(&idx);
}
}
......@@ -215,8 +202,8 @@ struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
CHECK(common::IsPureMath(node->stride))
<< node->stride << "is not a pure math!";
Simplify(&node->base);
Simplify(&node->stride);
PartialSimplify(&node->base);
PartialSimplify(&node->stride);
}
// ramp + ramp
void Visit(const Add* op, Expr* expr) override {
......@@ -370,7 +357,7 @@ void Simplify(Expr* expr) {
SimplifyIfThenElseMutator()(expr);
common::cas_intervals_t var_intervals;
SimplifyButStoreLoadMutator mutator(var_intervals);
SimplifyNoPureMathMutator mutator(var_intervals);
mutator(expr);
ReplaceFracWithDivMutator()(expr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册