未验证 提交 1653f99f 编写于 作者: J jiangcheng 提交者: GitHub

add deny param list to solve unuse param cannot found problem (#36996)

* add deny param list to solve unuse param cannot found the problem

* enclosure deny list in a function

* update by review advice
上级 bea0c9f5
...@@ -56,6 +56,63 @@ namespace { ...@@ -56,6 +56,63 @@ namespace {
// & FLAGS_deny_cinn_ops. // & FLAGS_deny_cinn_ops.
constexpr char kDelim[] = ";"; constexpr char kDelim[] = ";";
const std::unordered_map<std::string, std::unordered_set<std::string>>
kDenyParamMap = {{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) {
std::unordered_set<std::string> deny_var_set;
auto get_debug_info = [](const std::unordered_set<std::string>& var_names) {
std::string debug_info = "[";
for (auto& var : var_names) {
debug_info.append(var);
debug_info.append(", ");
}
debug_info.append("]");
return debug_info;
};
for (auto* op : cluster) {
if (kDenyParamMap.count(op->Name())) {
const auto* desc = op->Op();
PADDLE_ENFORCE_NE(desc, nullptr,
platform::errors::PreconditionNotMet(
"The Op %s's OpDesc should not be NULL, which has "
"a parameter in kDenyParamMap.",
op->Name().c_str()));
auto deny_param_names = kDenyParamMap.at(op->Name());
VLOG(4) << "We found deny param " << get_debug_info(deny_param_names)
<< " in op [" << op->Name() << "].";
for (const auto& param_name : deny_param_names) {
if (desc->Inputs().count(param_name)) {
const auto& arg_names = desc->Input(param_name);
for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
}
}
if (desc->HasOutput(param_name)) {
const auto& arg_names = desc->Output(param_name);
for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
}
}
}
}
}
VLOG(4) << "All deny var names are " << get_debug_info(deny_var_set);
return deny_var_set;
}
std::unordered_set<std::string> StringSplit(const std::string& str, std::unordered_set<std::string> StringSplit(const std::string& str,
const std::string& delim) { const std::string& delim) {
std::regex reg(delim); std::regex reg(delim);
...@@ -240,17 +297,24 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -240,17 +297,24 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
// out-graph should not using this node at all. // out-graph should not using this node at all.
// cluster_inputs & cluster_outputs & cluster_internals == NULL // cluster_inputs & cluster_outputs & cluster_internals == NULL
// cluster_outputs | cluster_internals == all graph op's outputs node // cluster_outputs | cluster_internals == all graph op's outputs node
void AnalyseClusterVariables(const GraphNodeSet& cluster, void AnalyseClusterVariables(
GraphNodeSet* cluster_inputs, const GraphNodeSet& cluster,
GraphNodeSet* cluster_outputs, const std::unordered_set<std::string>& deny_var_set,
GraphNodeSet* cluster_internals) { GraphNodeSet* cluster_inputs, GraphNodeSet* cluster_outputs,
GraphNodeSet* cluster_internals) {
// collecting all input and output of op // collecting all input and output of op
for (auto* op_node : cluster) { for (auto* op_node : cluster) {
const auto& op_name = op_node->Name();
for (auto* input_var_node : op_node->inputs) { for (auto* input_var_node : op_node->inputs) {
cluster_inputs->insert(input_var_node); if (!deny_var_set.count(input_var_node->Name())) {
// ignore deny var node
cluster_inputs->insert(input_var_node);
}
} }
for (auto* output_var_node : op_node->outputs) { for (auto* output_var_node : op_node->outputs) {
cluster_outputs->insert(output_var_node); if (!deny_var_set.count(output_var_node->Name())) {
cluster_outputs->insert(output_var_node);
}
} }
} }
// remove output node from cluster_inputs, // remove output node from cluster_inputs,
...@@ -294,22 +358,25 @@ void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs, ...@@ -294,22 +358,25 @@ void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
void AddCinnOpToGraph(const GraphNodeSet& cluster, void AddCinnOpToGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_outputs,
const std::string& compilation_key, Graph* graph) { const std::string& compilation_key,
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) {
// Add the cinn launch op // Add the cinn launch op
framework::OpDesc cinn_op_desc; framework::OpDesc cinn_op_desc;
cinn_op_desc.SetType(kCinnLaunchOp); cinn_op_desc.SetType(kCinnLaunchOp);
std::vector<std::string> input_names; std::vector<std::string> input_names;
std::for_each(cluster_inputs.begin(), cluster_inputs.end(), std::for_each(cluster_inputs.begin(), cluster_inputs.end(),
[&input_names](Node* n) { [&input_names, &deny_var_set](Node* n) {
if (n->Var() != nullptr) { if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
input_names.emplace_back(n->Name()); input_names.emplace_back(n->Name());
} }
}); });
cinn_op_desc.SetInput("X", input_names); cinn_op_desc.SetInput("X", input_names);
std::vector<std::string> output_names; std::vector<std::string> output_names;
std::for_each(cluster_outputs.begin(), cluster_outputs.end(), std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
[&output_names](Node* n) { [&output_names, &deny_var_set](Node* n) {
if (n->Var() != nullptr) { if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
output_names.emplace_back(n->Name()); output_names.emplace_back(n->Name());
} }
}); });
...@@ -341,15 +408,14 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster, ...@@ -341,15 +408,14 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
// kCinnLaunchOp, and inputs ares cluster_inputs and outputs are // kCinnLaunchOp, and inputs ares cluster_inputs and outputs are
// cluster_outputs. // cluster_outputs.
// Meanwhile, move all links of cluster to the cinn op. // Meanwhile, move all links of cluster to the cinn op.
void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster, void ReplaceSubGraphWithCinnOpNode(
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_internals, const std::string& compilation_key,
const std::string& compilation_key, const std::unordered_set<std::string>& deny_var_set, Graph* graph) {
Graph* graph) {
// Add the cinn op node whose name is "kCinnLaunchOp" into graph // Add the cinn op node whose name is "kCinnLaunchOp" into graph
AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key, AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key,
graph); deny_var_set, graph);
// Remove the cinn subgraph from graph // Remove the cinn subgraph from graph
RemoveSubGraphFromGraph(cluster, cluster_internals, graph); RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
} }
...@@ -398,9 +464,11 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -398,9 +464,11 @@ void SearchAllSubgraphs(Graph* graph) {
// Classify var node to inputs, outputs, and internals. // Classify var node to inputs, outputs, and internals.
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end()); GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
auto deny_var_set = GetDenyVarNames(cluster_set);
GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals; GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
AnalyseClusterVariables(cluster_set, &cluster_inputs, &cluster_outputs, AnalyseClusterVariables(cluster_set, deny_var_set, &cluster_inputs,
&cluster_internals); &cluster_outputs, &cluster_internals);
VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set); VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs); VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
...@@ -417,7 +485,8 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -417,7 +485,8 @@ void SearchAllSubgraphs(Graph* graph) {
// Replace the found cluster to a new cinn op node // Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs, ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs,
cluster_internals, compilation_key, graph); cluster_internals, compilation_key,
deny_var_set, graph);
} }
} }
} // namespace } // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册