未验证 提交 d4ca7ffb 编写于 作者: Z Zhen Wang 提交者: GitHub

Add feed&fetch as default deny ops. (#44708)

上级 d0cf9d9d
......@@ -62,6 +62,8 @@ const std::unordered_map<std::string, std::unordered_set<std::string>>
kDenyParamMap = {{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
const std::unordered_set<std::string> kDefaultDenyOps = {"feed", "fetch"};
std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) {
std::unordered_set<std::string> deny_var_set;
......@@ -560,22 +562,24 @@ void SearchAllSubgraphs(Graph* graph) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
auto teller = [&allow_ops, &deny_ops](const Node* node) {
const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node->Name()) != nullptr;
node_name) != nullptr;
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (allow_ops.size()) {
return registered && allow_ops.count(node->Name());
if (!allow_ops.empty()) {
return registered && allow_ops.count(node_name);
}
// if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops
if (deny_ops.size()) {
return registered && !deny_ops.count(node->Name());
if (!deny_ops.empty()) {
return registered && !deny_ops.count(node_name);
}
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN
return registered && (node->IsOp() && !IsInplaceOp(*node->Op()));
return registered && !kDefaultDenyOps.count(node_name) &&
(node->IsOp() && !IsInplaceOp(*node->Op()));
};
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册