diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index 5144543e05bad845389da6229617f880243cefad..1529471bf46f19220b506fb0b02783ef86f78a43 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -243,8 +243,31 @@ struct SimplifyIfThenElseMutator : public ir::IRMutator<> { auto* node = expr->As(); node->condition = common::AutoSimplify(node->condition); - if (node->true_case.defined()) Visit(&node->true_case, &node->true_case); - if (node->false_case.defined()) Visit(&node->false_case, &node->false_case); + auto* condition_int = node->condition.As(); + auto* condition_uint = node->condition.As(); + int64_t value; + if (condition_int || condition_uint) { + if (condition_int) { + value = condition_int->value; + } else { + value = condition_uint->value; + } + if (value) { + *expr = op->true_case; + } else { + if (op->false_case.defined()) { + *expr = op->false_case; + } else { + // null condition + *expr = ir::Block::Make({}); + } + } + } + if (expr->As()) { + if (node->true_case.defined()) Visit(&node->true_case, &node->true_case); + if (node->false_case.defined()) + Visit(&node->false_case, &node->false_case); + } } }; diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index d0e800a2cd4f969546bf26c69551e7f62004d994..257ef78b313424888206a661ee570fa2bee3c9a2 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -79,8 +79,9 @@ Expr Optimize(Expr e, Simplify(&copied); VLOG(10) << "After Optimize Simplify:" << copied; - IfSimplify(&copied); - VLOG(10) << "After Optimize IfSimplify:" << copied; + // TODO(LiuYang): I attends to remove this part code, I integate it into + // ifthenelse part IfSimplify(&copied); VLOG(10) << "After Optimize + // IfSimplify:" << copied; if (runtime_debug_info) { LOG(WARNING) << "Turn on runtime debug information output";