未验证 提交 d4ca7ffb 编写于 作者: Z Zhen Wang 提交者: GitHub

Add feed&fetch as default deny ops. (#44708)

上级 d0cf9d9d
...@@ -62,6 +62,8 @@ const std::unordered_map<std::string, std::unordered_set<std::string>> ...@@ -62,6 +62,8 @@ const std::unordered_map<std::string, std::unordered_set<std::string>>
kDenyParamMap = {{"batch_norm", {"ReserveSpace"}}, kDenyParamMap = {{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}}; {"batch_norm_grad", {"ReserveSpace"}}};
const std::unordered_set<std::string> kDefaultDenyOps = {"feed", "fetch"};
std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) { std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) {
std::unordered_set<std::string> deny_var_set; std::unordered_set<std::string> deny_var_set;
...@@ -560,22 +562,24 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -560,22 +562,24 @@ void SearchAllSubgraphs(Graph* graph) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
auto teller = [&allow_ops, &deny_ops](const Node* node) { auto teller = [&allow_ops, &deny_ops](const Node* node) {
const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node->Name()) != nullptr; node_name) != nullptr;
// if the op type is registered in CINN and allow_ops is not empty, return // if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops // true only when it is in allow_ops
if (allow_ops.size()) { if (!allow_ops.empty()) {
return registered && allow_ops.count(node->Name()); return registered && allow_ops.count(node_name);
} }
// if the op type is registered in CINN and deny_ops is not empty, return // if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops // true only when it is not in deny_ops
if (deny_ops.size()) { if (!deny_ops.empty()) {
return registered && !deny_ops.count(node->Name()); return registered && !deny_ops.count(node_name);
} }
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN // return true only when it is registered in CINN
return registered && (node->IsOp() && !IsInplaceOp(*node->Op())); return registered && !kDefaultDenyOps.count(node_name) &&
(node->IsOp() && !IsInplaceOp(*node->Op()));
}; };
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册