未验证 提交 ce57365d 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] temp fix batch_norm check as inplace op bug (#49738)

上级 30f5e39b
...@@ -52,6 +52,16 @@ using framework::ir::Node; ...@@ -52,6 +52,16 @@ using framework::ir::Node;
using GraphNodeVec = std::vector<Node*>; using GraphNodeVec = std::vector<Node*>;
using GraphNodeMap = std::unordered_map<Node*, Node*>; using GraphNodeMap = std::unordered_map<Node*, Node*>;
std::string GetDebugInfo(const std::unordered_set<std::string>& var_names) {
std::string debug_info = "[";
for (auto& var : var_names) {
debug_info.append(var);
debug_info.append(", ");
}
debug_info.append("]");
return debug_info;
}
OpTransInfo::OpTransInfo() { OpTransInfo::OpTransInfo() {
// judgment condition for the dynamic slice // judgment condition for the dynamic slice
dynamic_op_cond_.emplace("slice", [](const ir::Node& node) -> bool { dynamic_op_cond_.emplace("slice", [](const ir::Node& node) -> bool {
...@@ -115,16 +125,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -115,16 +125,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const GraphNodeSet& cluster) const { const GraphNodeSet& cluster) const {
std::unordered_set<std::string> deny_var_set; std::unordered_set<std::string> deny_var_set;
auto get_debug_info = [](const std::unordered_set<std::string>& var_names) {
std::string debug_info = "[";
for (auto& var : var_names) {
debug_info.append(var);
debug_info.append(", ");
}
debug_info.append("]");
return debug_info;
};
for (auto* op : cluster) { for (auto* op : cluster) {
if (deny_param_cond_.count(op->Name())) { if (deny_param_cond_.count(op->Name())) {
const auto* desc = op->Op(); const auto* desc = op->Op();
...@@ -136,7 +136,7 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -136,7 +136,7 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
op->Name().c_str())); op->Name().c_str()));
auto deny_param_names = deny_param_cond_.at(op->Name()); auto deny_param_names = deny_param_cond_.at(op->Name());
VLOG(4) << "We found deny param " << get_debug_info(deny_param_names) VLOG(4) << "We found deny param " << GetDebugInfo(deny_param_names)
<< " in op [" << op->Name() << "]."; << " in op [" << op->Name() << "].";
for (const auto& param_name : deny_param_names) { for (const auto& param_name : deny_param_names) {
...@@ -161,16 +161,51 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -161,16 +161,51 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
} }
} }
VLOG(4) << "All deny var names are " << get_debug_info(deny_var_set); VLOG(4) << "All deny var names are " << GetDebugInfo(deny_var_set);
return deny_var_set; return deny_var_set;
} }
bool OpTransInfo::IsInplaceOp(const OpDesc& op_desc) { std::unordered_set<std::string> OpTransInfo::GetIgnoreInplaceVarNames(
const OpDesc& op_desc) const {
if (!ignore_inplace_param_cond_.count(op_desc.Type())) {
return {};
}
const auto& ignore_inplace_names =
ignore_inplace_param_cond_.at(op_desc.Type());
VLOG(4) << "We found ignore inplace param "
<< GetDebugInfo(ignore_inplace_names) << " in op [" << op_desc.Type()
<< "].";
std::unordered_set<std::string> ignore_inplace_set;
for (const auto& param_name : ignore_inplace_names) {
if (op_desc.HasOutput(param_name)) {
const auto& arg_names = op_desc.Output(param_name);
ignore_inplace_set.insert(arg_names.begin(), arg_names.end());
}
}
VLOG(4) << "All ignore inplace var names are "
<< GetDebugInfo(ignore_inplace_set);
return ignore_inplace_set;
}
bool OpTransInfo::IsInplaceOp(
const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const {
const auto& ignore_inplace_set = GetIgnoreInplaceVarNames(op_desc);
auto inputs = op_desc.InputArgumentNames(); auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end()); std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) { for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0) return true; if (input_set.count(name) > 0 && !deny_var_names.count(name) &&
!ignore_inplace_set.count(name)) {
VLOG(4) << "The argument " << name << " in op " << op_desc.Type()
<< " is a inplace op, skip!";
return true;
}
} }
return false; return false;
} }
...@@ -630,8 +665,11 @@ void ReplaceSubGraphWithCinnOpNode( ...@@ -630,8 +665,11 @@ void ReplaceSubGraphWithCinnOpNode(
void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
OpTransInfo trans_info; OpTransInfo trans_info;
auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) { const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes());
auto teller = [&allow_ops, &deny_ops, &trans_info, &deny_var_set](
const Node* node) {
const auto& node_name = node->Name(); const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node_name) != nullptr; node_name) != nullptr;
...@@ -643,7 +681,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -643,7 +681,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
bool is_support = bool is_support =
registered && !trans_info.default_deny_ops().count(node_name) && registered && !trans_info.default_deny_ops().count(node_name) &&
!is_dynamic && (node->IsOp() && !trans_info.IsInplaceOp(*node->Op())); !is_dynamic &&
(node->IsOp() && !trans_info.IsInplaceOp(*node->Op(), deny_var_set));
// if the op type is registered in CINN and allow_ops is not empty, return // if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops // true only when it is in allow_ops
if (!allow_ops.empty()) { if (!allow_ops.empty()) {
...@@ -659,8 +698,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -659,8 +698,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
// return true only when it is registered in CINN // return true only when it is registered in CINN
return is_support; return is_support;
}; };
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The allowed Cinn Ops: " << GetDebugInfo(allow_ops);
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; VLOG(4) << "The denied Cinn Ops: " << GetDebugInfo(deny_ops);
std::vector<GraphNodeVec> clusters = CinnSubgraphDetector(graph, teller)(); std::vector<GraphNodeVec> clusters = CinnSubgraphDetector(graph, teller)();
LOG(INFO) << "--- [build_cinn_pass] detected " << clusters.size() LOG(INFO) << "--- [build_cinn_pass] detected " << clusters.size()
<< " cinn supported subgraphs"; << " cinn supported subgraphs";
......
...@@ -67,7 +67,11 @@ class OpTransInfo { ...@@ -67,7 +67,11 @@ class OpTransInfo {
std::unordered_set<std::string> GetDenyVarNames( std::unordered_set<std::string> GetDenyVarNames(
const GraphNodeSet& cluster) const; const GraphNodeSet& cluster) const;
static bool IsInplaceOp(const OpDesc& op_desc); std::unordered_set<std::string> GetIgnoreInplaceVarNames(
const OpDesc& op_desc) const;
bool IsInplaceOp(const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const;
private: private:
DyOpCondT dynamic_op_cond_; DyOpCondT dynamic_op_cond_;
...@@ -75,6 +79,9 @@ class OpTransInfo { ...@@ -75,6 +79,9 @@ class OpTransInfo {
DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}}, DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}}; {"batch_norm_grad", {"ReserveSpace"}}};
DeParamCondT ignore_inplace_param_cond_{
{"batch_norm", {"MeanOut", "VarianceOut"}}};
std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"}; std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"};
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册