From 4526f61ef922e81151bf9fcc917266b6d5610e08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Tue, 15 Aug 2023 14:24:03 +0800 Subject: [PATCH] refactor ifthenelse optimize (#56274) --- paddle/cinn/optim/ir_simplify.cc | 27 +++++++++++++++++++++++++-- paddle/cinn/optim/optimize.cc | 5 +++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index 5144543e05b..1529471bf46 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -243,8 +243,31 @@ struct SimplifyIfThenElseMutator : public ir::IRMutator<> { auto* node = expr->As(); node->condition = common::AutoSimplify(node->condition); - if (node->true_case.defined()) Visit(&node->true_case, &node->true_case); - if (node->false_case.defined()) Visit(&node->false_case, &node->false_case); + auto* condition_int = node->condition.As(); + auto* condition_uint = node->condition.As(); + int64_t value; + if (condition_int || condition_uint) { + if (condition_int) { + value = condition_int->value; + } else { + value = condition_uint->value; + } + if (value) { + *expr = op->true_case; + } else { + if (op->false_case.defined()) { + *expr = op->false_case; + } else { + // null condition + *expr = ir::Block::Make({}); + } + } + } + if (expr->As()) { + if (node->true_case.defined()) Visit(&node->true_case, &node->true_case); + if (node->false_case.defined()) + Visit(&node->false_case, &node->false_case); + } } }; diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index d0e800a2cd4..257ef78b313 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -79,8 +79,9 @@ Expr Optimize(Expr e, Simplify(&copied); VLOG(10) << "After Optimize Simplify:" << copied; - IfSimplify(&copied); - VLOG(10) << "After Optimize IfSimplify:" << copied; + // TODO(LiuYang): I attends to remove this part code, I integate it into + // ifthenelse part IfSimplify(&copied); VLOG(10) << "After Optimize + // IfSimplify:" << copied; if (runtime_debug_info) { LOG(WARNING) << "Turn on runtime debug information output"; -- GitLab