diff --git a/paddle/cinn/hlir/op/reduction.cc b/paddle/cinn/hlir/op/reduction.cc index 3cd4a3bdd79dd8a25546dffc476741532ed0881a..a396aec315af46d578978fb3a2a9aae38ac5b34f 100644 --- a/paddle/cinn/hlir/op/reduction.cc +++ b/paddle/cinn/hlir/op/reduction.cc @@ -182,6 +182,7 @@ std::shared_ptr StrategyForReduce( // support the length-1 for loop yet. So we simplify here. The todo // is that remove SimplifyForLoops below and change reduction schedule optim::SimplifyForLoops(&temp); + optim::SimplifyBlocks(&temp); vec_ast.emplace_back(temp); } else if (arg_pack[i].is_tensor()) { Expr temp = arg_pack[i]; diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index 8856b144488b65c61b93ff79295a957af2fb5274..3d187f8a4138ce0622a0e0c79e8af76ac0d90f7a 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -339,12 +339,9 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> { std::string var_name = node->loop_var->name; var_intervals.emplace( var_name, common::CasInterval{min_i->value, extent_i->value - 1}); - if (node->body.As() && - node->body.As()->stmts.size() == 1) { - *expr = node->body.As()->stmts[0]; - } else { - *expr = node->body; - } + + *expr = node->body; + Visit(expr, expr); var_intervals.erase(var_name); } else {