From c170074df0137ed578553b3ce6095308407c3fed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Wed, 6 Sep 2023 16:11:29 +0800 Subject: [PATCH] rename StoreButMutator and fix RampMutator (#56966) --- paddle/cinn/optim/ir_simplify.cc | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index 3d187f8a413..3c0f9298dae 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -53,9 +53,9 @@ void PartialSimplify( } //! Simplify the expression but Load. -struct SimplifyButStoreLoadMutator : public ir::IRMutator { +struct SimplifyNoPureMathMutator : public ir::IRMutator { common::cas_intervals_t& var_intervals; - explicit SimplifyButStoreLoadMutator( + explicit SimplifyNoPureMathMutator( common::cas_intervals_t& var_intervals) // NOLINT : var_intervals(var_intervals) {} @@ -76,19 +76,6 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator { __(Max) #undef __ - void Visit(const Ramp* op, Expr* expr) override { - auto* node = expr->As(); - CHECK(common::IsPureMath(node->base)); - CHECK(common::IsPureMath(node->stride)); - PartialSimplify(&node->base, var_intervals); - PartialSimplify(&node->stride, var_intervals); - } - - void Visit(const Cast* op, Expr* expr) override { - auto* node = expr->As(); - Visit(&node->v(), &node->v()); - } - void Visit(const PolyFor* op, Expr* expr) override { auto* node = expr->As(); node->condition = common::SolveInequality(op->condition, op->iterator); @@ -138,7 +125,7 @@ struct SimplifyLoadMutator : public ir::IRMutator { if (common::IsPureMath(idx)) { PartialSimplify(&idx, var_intervals_); } else { - SimplifyButStoreLoadMutator mutator(var_intervals_); + SimplifyNoPureMathMutator mutator(var_intervals_); mutator(&idx); } } @@ -176,7 +163,7 @@ struct SimplifyStoreMutator : public ir::IRMutator { if (common::IsPureMath(idx)) { PartialSimplify(&idx, var_intervals_); } else { - SimplifyButStoreLoadMutator mutator(var_intervals_); + SimplifyNoPureMathMutator mutator(var_intervals_); mutator(&idx); } } @@ -215,8 +202,8 @@ struct SimplifyRampMutator : public ir::IRMutator { CHECK(common::IsPureMath(node->stride)) << node->stride << "is not a pure math!"; - Simplify(&node->base); - Simplify(&node->stride); + PartialSimplify(&node->base); + PartialSimplify(&node->stride); } // ramp + ramp void Visit(const Add* op, Expr* expr) override { @@ -370,7 +357,7 @@ void Simplify(Expr* expr) { SimplifyIfThenElseMutator()(expr); common::cas_intervals_t var_intervals; - SimplifyButStoreLoadMutator mutator(var_intervals); + SimplifyNoPureMathMutator mutator(var_intervals); mutator(expr); ReplaceFracWithDivMutator()(expr); -- GitLab