From ed6b9567893e9cfc395b91fde4c6865ed7b6bffd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Fri, 7 Jul 2023 14:53:21 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90CINN=E3=80=91refactor=20ir=5Fvisitor?= =?UTF-8?q?=20(#55171)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR delete middle ir_visitor class and thus we can avoid middle virtual function call and codes look more clean pcard-72718 --- .../cost_model/feature_extractor.cc | 4 +- .../cost_model/feature_extractor.h | 2 +- .../search_space/search_state.cc | 4 +- paddle/cinn/backends/llvm/codegen_llvm.h | 4 +- paddle/cinn/backends/modular.cc | 4 +- paddle/cinn/ir/collect_ir_nodes.cc | 2 +- paddle/cinn/ir/ir_compare.cc | 2 +- paddle/cinn/ir/ir_compare.h | 2 +- paddle/cinn/ir/ir_mutator.h | 156 ++++++++++-------- paddle/cinn/ir/ir_printer.cc | 2 +- paddle/cinn/ir/ir_printer.h | 2 +- paddle/cinn/ir/ir_visitor.h | 15 +- paddle/cinn/optim/insert_debug_log_callee.cc | 6 +- paddle/cinn/optim/ir_copy.cc | 24 +-- paddle/cinn/pybind/ir.cc | 1 + 15 files changed, 126 insertions(+), 104 deletions(-) mode change 100755 => 100644 paddle/cinn/ir/ir_mutator.h diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc b/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc index 78a28164f3c..ba0e6239bee 100644 --- a/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc +++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc @@ -45,7 +45,9 @@ using namespace ::cinn::ir; // NOLINT FeatureExtractor::FeatureExtractor() {} -void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); } +void FeatureExtractor::Visit(const Expr *x) { + IRVisitorRequireReImpl::Visit(x); +} Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr, const common::Target &target) { diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor.h b/paddle/cinn/auto_schedule/cost_model/feature_extractor.h index 9f3d3762eb6..c6109f658e1 100644 --- a/paddle/cinn/auto_schedule/cost_model/feature_extractor.h +++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor.h @@ -37,7 +37,7 @@ namespace cinn { namespace auto_schedule { -class FeatureExtractor : public ir::IRVisitor { +class FeatureExtractor : public ir::IRVisitorRequireReImpl { public: FeatureExtractor(); Feature Extract(const ir::ModuleExpr& mod_expr, const common::Target& target); diff --git a/paddle/cinn/auto_schedule/search_space/search_state.cc b/paddle/cinn/auto_schedule/search_space/search_state.cc index 973270f493e..852ea25259c 100644 --- a/paddle/cinn/auto_schedule/search_space/search_state.cc +++ b/paddle/cinn/auto_schedule/search_space/search_state.cc @@ -71,7 +71,7 @@ bool operator<(const SearchState& left, const SearchState& right) { } // Visit every node by expanding all of their fields in dfs order -class DfsWithExprsFields : public ir::IRVisitor { +class DfsWithExprsFields : public ir::IRVisitorRequireReImpl { protected: #define __m(t__) \ void Visit(const ir::t__* x) override { \ @@ -85,7 +85,7 @@ class DfsWithExprsFields : public ir::IRVisitor { NODETY_FORALL(__m) #undef __m - void Visit(const Expr* expr) override { IRVisitor::Visit(expr); } + void Visit(const Expr* expr) override { IRVisitorRequireReImpl::Visit(expr); } }; // Generate a reduce hash of a AST tree by combining hash of each AST node diff --git a/paddle/cinn/backends/llvm/codegen_llvm.h b/paddle/cinn/backends/llvm/codegen_llvm.h index bf5be73adcb..ff885db2c8e 100644 --- a/paddle/cinn/backends/llvm/codegen_llvm.h +++ b/paddle/cinn/backends/llvm/codegen_llvm.h @@ -39,11 +39,11 @@ namespace cinn { namespace backends { -class LLVMIRVisitor : public ir::IRVisitorBase { +class LLVMIRVisitor : public ir::IRVisitorRequireReImpl { public: LLVMIRVisitor() = default; - using ir::IRVisitorBase::Visit; + using ir::IRVisitorRequireReImpl::Visit; #define __m(t__) virtual llvm::Value *Visit(const ir::t__ *x) = 0; NODETY_FORALL(__m) #undef __m diff --git a/paddle/cinn/backends/modular.cc b/paddle/cinn/backends/modular.cc index 41a74643d66..fb736154c7b 100644 --- a/paddle/cinn/backends/modular.cc +++ b/paddle/cinn/backends/modular.cc @@ -19,13 +19,13 @@ namespace cinn { namespace backends { -class ModularEvaluator : public ir::IRVisitorBase { +class ModularEvaluator : public ir::IRVisitorRequireReImpl { public: explicit ModularEvaluator(const std::map& mod_map) : mod_map_(mod_map) {} ModularEntry Eval(const Expr& e) { - return ir::IRVisitorBase::Visit(&e); + return ir::IRVisitorRequireReImpl::Visit(&e); } ModularEntry Visit(const ir::IntImm* op) { diff --git a/paddle/cinn/ir/collect_ir_nodes.cc b/paddle/cinn/ir/collect_ir_nodes.cc index 4c00ac975cc..74a13c2e61b 100644 --- a/paddle/cinn/ir/collect_ir_nodes.cc +++ b/paddle/cinn/ir/collect_ir_nodes.cc @@ -24,7 +24,7 @@ namespace ir { namespace { -struct IrNodesCollector : public IRVisitor { +struct IrNodesCollector : public IRVisitorRequireReImpl { using teller_t = std::function; using handler_t = std::function; diff --git a/paddle/cinn/ir/ir_compare.cc b/paddle/cinn/ir/ir_compare.cc index 3b9f357d546..9832343c5ff 100644 --- a/paddle/cinn/ir/ir_compare.cc +++ b/paddle/cinn/ir/ir_compare.cc @@ -32,7 +32,7 @@ bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) { VLOG(5) << "Not equal on Expr, someone not defined"; } bool equal = lhs->node_type() == rhs->node_type(); - equal = equal && IRVisitorBase::Visit(&lhs, &rhs); + equal = equal && IRVisitorRequireReImpl::Visit(&lhs, &rhs); if (!equal) { VLOG(5) << "Not equal on Expr, lhs:[type:" diff --git a/paddle/cinn/ir/ir_compare.h b/paddle/cinn/ir/ir_compare.h index 75e9bcf2dcc..9ea6a13c79d 100644 --- a/paddle/cinn/ir/ir_compare.h +++ b/paddle/cinn/ir/ir_compare.h @@ -23,7 +23,7 @@ namespace ir { // Determine whether two ir AST trees are euqal by comparing their struct and // fields of each node through dfs visitor -class IrEqualVisitor : public IRVisitorBase { +class IrEqualVisitor : public IRVisitorRequireReImpl { public: explicit IrEqualVisitor(bool allow_name_suffix_diff = false) : allow_name_suffix_diff_(allow_name_suffix_diff) {} diff --git a/paddle/cinn/ir/ir_mutator.h b/paddle/cinn/ir/ir_mutator.h old mode 100755 new mode 100644 index 9a7fa33756f..d83c5415220 --- a/paddle/cinn/ir/ir_mutator.h +++ b/paddle/cinn/ir/ir_mutator.h @@ -26,7 +26,7 @@ namespace ir { //! T might be Expr* or const Expr* template -class IRMutator : public IRVisitorBase { +class IRMutator : public IRVisitorRequireReImpl { public: void Visit(const Expr *expr, T op) override; @@ -37,22 +37,22 @@ class IRMutator : public IRVisitorBase { template void IRMutator::Visit(const Expr *expr, T op) { - IRVisitorBase::Visit(expr, op); + IRVisitorRequireReImpl::Visit(expr, op); } -#define UNARY_OP_IMPL(op__) \ - template \ - void IRMutator::Visit(const op__ *expr, T op) { \ - auto *node = op->template As(); \ - IRVisitorBase::Visit(&node->v(), &node->v()); \ +#define UNARY_OP_IMPL(op__) \ + template \ + void IRMutator::Visit(const op__ *expr, T op) { \ + auto *node = op->template As(); \ + IRVisitorRequireReImpl::Visit(&node->v(), &node->v()); \ } -#define BINARY_OP_IMPL(op__) \ - template \ - void IRMutator::Visit(const op__ *expr, T op) { \ - auto *node = op->template As(); \ - IRVisitorBase::Visit(&node->a(), &node->a()); \ - IRVisitorBase::Visit(&node->b(), &node->b()); \ +#define BINARY_OP_IMPL(op__) \ + template \ + void IRMutator::Visit(const op__ *expr, T op) { \ + auto *node = op->template As(); \ + IRVisitorRequireReImpl::Visit(&node->a(), &node->a()); \ + IRVisitorRequireReImpl::Visit(&node->b(), &node->b()); \ } NODETY_UNARY_OP_FOR_EACH(UNARY_OP_IMPL) @@ -77,172 +77,181 @@ void IRMutator::Visit(const Cast *expr, T op) { template void IRMutator::Visit(const For *expr, T op) { auto *node = op->template As(); - IRVisitorBase::Visit(&node->min, &node->min); - IRVisitorBase::Visit(&node->extent, &node->extent); - IRVisitorBase::Visit(&node->body, &node->body); + IRVisitorRequireReImpl::Visit(&node->min, &node->min); + IRVisitorRequireReImpl::Visit(&node->extent, &node->extent); + IRVisitorRequireReImpl::Visit(&node->body, &node->body); } template void IRMutator::Visit(const PolyFor *expr, T op) { auto *node = op->template As(); - // IRVisitorBase::Visit(&node->iterator, &node->iterator); - IRVisitorBase::Visit(&node->init, &node->init); - IRVisitorBase::Visit(&node->condition, &node->condition); - IRVisitorBase::Visit(&node->inc, &node->inc); - IRVisitorBase::Visit(&node->body, &node->body); + // IRVisitorRequireReImpl::Visit(&node->iterator, + // &node->iterator); + IRVisitorRequireReImpl::Visit(&node->init, &node->init); + IRVisitorRequireReImpl::Visit(&node->condition, &node->condition); + IRVisitorRequireReImpl::Visit(&node->inc, &node->inc); + IRVisitorRequireReImpl::Visit(&node->body, &node->body); } template void IRMutator::Visit(const Select *expr, T op) { auto *node = op->template As