From edfc1ab8d50bf12ecf31366b60e00a39e2164e1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Thu, 17 Aug 2023 17:43:04 +0800 Subject: [PATCH] rewrite SimplifyForLoop (#56350) --- paddle/cinn/hlir/op/reduction.cc | 1 + paddle/cinn/optim/ir_simplify.cc | 9 +++------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/paddle/cinn/hlir/op/reduction.cc b/paddle/cinn/hlir/op/reduction.cc index 3cd4a3bdd79..a396aec315a 100644 --- a/paddle/cinn/hlir/op/reduction.cc +++ b/paddle/cinn/hlir/op/reduction.cc @@ -182,6 +182,7 @@ std::shared_ptr StrategyForReduce( // support the length-1 for loop yet. So we simplify here. The todo // is that remove SimplifyForLoops below and change reduction schedule optim::SimplifyForLoops(&temp); + optim::SimplifyBlocks(&temp); vec_ast.emplace_back(temp); } else if (arg_pack[i].is_tensor()) { Expr temp = arg_pack[i]; diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index 8856b144488..3d187f8a413 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -339,12 +339,9 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> { std::string var_name = node->loop_var->name; var_intervals.emplace( var_name, common::CasInterval{min_i->value, extent_i->value - 1}); - if (node->body.As() && - node->body.As()->stmts.size() == 1) { - *expr = node->body.As()->stmts[0]; - } else { - *expr = node->body; - } + + *expr = node->body; + Visit(expr, expr); var_intervals.erase(var_name); } else { -- GitLab