From d4ca7ffbdf77232c75a96aa9ae396e7b33278cdb Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Fri, 5 Aug 2022 19:02:43 +0800 Subject: [PATCH] Add feed&fetch as default deny ops. (#44708) --- .../framework/paddle2cinn/build_cinn_pass.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 59364616494..f0f35ea28cc 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -62,6 +62,8 @@ const std::unordered_map> kDenyParamMap = {{"batch_norm", {"ReserveSpace"}}, {"batch_norm_grad", {"ReserveSpace"}}}; +const std::unordered_set kDefaultDenyOps = {"feed", "fetch"}; + std::unordered_set GetDenyVarNames(const GraphNodeSet& cluster) { std::unordered_set deny_var_set; @@ -560,22 +562,24 @@ void SearchAllSubgraphs(Graph* graph) { auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); auto teller = [&allow_ops, &deny_ops](const Node* node) { + const auto& node_name = node->Name(); bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( - node->Name()) != nullptr; + node_name) != nullptr; // if the op type is registered in CINN and allow_ops is not empty, return // true only when it is in allow_ops - if (allow_ops.size()) { - return registered && allow_ops.count(node->Name()); + if (!allow_ops.empty()) { + return registered && allow_ops.count(node_name); } // if the op type is registered in CINN and deny_ops is not empty, return // true only when it is not in deny_ops - if (deny_ops.size()) { - return registered && !deny_ops.count(node->Name()); + if (!deny_ops.empty()) { + return registered && !deny_ops.count(node_name); } // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, // return true only when it is registered in CINN - return registered && (node->IsOp() && !IsInplaceOp(*node->Op())); + return registered && !kDefaultDenyOps.count(node_name) && + (node->IsOp() && !IsInplaceOp(*node->Op())); }; VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; -- GitLab