build_cinn_pass: keep deny-listed variables out of the CINN launch op.

Adds kDenyParamMap (op type -> parameter names whose variables must be
ignored; currently "ReserveSpace" of batch_norm / batch_norm_grad) and
GetDenyVarNames(), which collects the argument names bound to those
parameters for every op in a cluster. The resulting name set is threaded
through AnalyseClusterVariables, ReplaceSubGraphWithCinnOpNode and
AddCinnOpToGraph so that such variables are skipped when the cluster's
inputs/outputs are collected and when the cinn op's "X" input names and
output names are built.

diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
index b90dbd7dcd845e8785e7f504f41ce5a81c97046d..0cff68c41eb1012e0e9a6ac6cd288042fdedb8c5 100644
--- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
+++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
@@ -56,6 +56,63 @@ namespace {
 // & FLAGS_deny_cinn_ops.
 constexpr char kDelim[] = ";";
 
+const std::unordered_map<std::string, std::unordered_set<std::string>>
+    kDenyParamMap = {{"batch_norm", {"ReserveSpace"}},
+                     {"batch_norm_grad", {"ReserveSpace"}}};
+
+std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) {
+  std::unordered_set<std::string> deny_var_set;
+
+  auto get_debug_info = [](const std::unordered_set<std::string>& var_names) {
+    std::string debug_info = "[";
+    for (auto& var : var_names) {
+      debug_info.append(var);
+      debug_info.append(", ");
+    }
+    debug_info.append("]");
+    return debug_info;
+  };
+
+  for (auto* op : cluster) {
+    if (kDenyParamMap.count(op->Name())) {
+      const auto* desc = op->Op();
+      PADDLE_ENFORCE_NE(desc, nullptr,
+                        platform::errors::PreconditionNotMet(
+                            "The Op %s's OpDesc should not be NULL, which has "
+                            "a parameter in kDenyParamMap.",
+                            op->Name().c_str()));
+
+      auto deny_param_names = kDenyParamMap.at(op->Name());
+      VLOG(4) << "We found deny param " << get_debug_info(deny_param_names)
+              << " in op [" << op->Name() << "].";
+
+      for (const auto& param_name : deny_param_names) {
+        if (desc->Inputs().count(param_name)) {
+          const auto& arg_names = desc->Input(param_name);
+          for (const auto& arg_name : arg_names) {
+            deny_var_set.insert(arg_name);
+            VLOG(4) << "deny param [" << param_name << "]'s argument name"
+                    << " is [" << arg_name << "].";
+          }
+        }
+
+        if (desc->HasOutput(param_name)) {
+          const auto& arg_names = desc->Output(param_name);
+          for (const auto& arg_name : arg_names) {
+            deny_var_set.insert(arg_name);
+            VLOG(4) << "deny param [" << param_name << "]'s argument name"
+                    << " is [" << arg_name << "].";
+          }
+        }
+      }
+    }
+  }
+
+  VLOG(4) << "All deny var names are " << get_debug_info(deny_var_set);
+
+  return deny_var_set;
+}
+
 std::unordered_set<std::string> StringSplit(const std::string& str,
                                             const std::string& delim) {
   std::regex reg(delim);
@@ -240,17 +297,24 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
 // out-graph should not using this node at all.
 // cluster_inputs & cluster_outputs & cluster_internals == NULL
 // cluster_outputs | cluster_internals == all graph op's outputs node
-void AnalyseClusterVariables(const GraphNodeSet& cluster,
-                             GraphNodeSet* cluster_inputs,
-                             GraphNodeSet* cluster_outputs,
-                             GraphNodeSet* cluster_internals) {
+void AnalyseClusterVariables(
+    const GraphNodeSet& cluster,
+    const std::unordered_set<std::string>& deny_var_set,
+    GraphNodeSet* cluster_inputs, GraphNodeSet* cluster_outputs,
+    GraphNodeSet* cluster_internals) {
   // collecting all input and output of op
   for (auto* op_node : cluster) {
+    const auto& op_name = op_node->Name();
     for (auto* input_var_node : op_node->inputs) {
-      cluster_inputs->insert(input_var_node);
+      if (!deny_var_set.count(input_var_node->Name())) {
+        // ignore deny var node
+        cluster_inputs->insert(input_var_node);
+      }
     }
     for (auto* output_var_node : op_node->outputs) {
-      cluster_outputs->insert(output_var_node);
+      if (!deny_var_set.count(output_var_node->Name())) {
+        cluster_outputs->insert(output_var_node);
+      }
     }
   }
   // remove output node from cluster_inputs,
@@ -294,22 +358,25 @@ void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
 void AddCinnOpToGraph(const GraphNodeSet& cluster,
                       const GraphNodeSet& cluster_inputs,
                       const GraphNodeSet& cluster_outputs,
-                      const std::string& compilation_key, Graph* graph) {
+                      const std::string& compilation_key,
+                      const std::unordered_set<std::string>& deny_var_set,
+                      Graph* graph) {
   // Add the cinn launch op
   framework::OpDesc cinn_op_desc;
   cinn_op_desc.SetType(kCinnLaunchOp);
   std::vector<std::string> input_names;
+
   std::for_each(cluster_inputs.begin(), cluster_inputs.end(),
-                [&input_names](Node* n) {
-                  if (n->Var() != nullptr) {
+                [&input_names, &deny_var_set](Node* n) {
+                  if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
                     input_names.emplace_back(n->Name());
                   }
                 });
   cinn_op_desc.SetInput("X", input_names);
   std::vector<std::string> output_names;
   std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
-                [&output_names](Node* n) {
-                  if (n->Var() != nullptr) {
+                [&output_names, &deny_var_set](Node* n) {
+                  if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
                     output_names.emplace_back(n->Name());
                   }
                 });
@@ -341,15 +408,14 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
 // kCinnLaunchOp, and inputs ares cluster_inputs and outputs are
 // cluster_outputs.
 // Meanwhile, move all links of cluster to the cinn op.
-void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
-                                   const GraphNodeSet& cluster_inputs,
-                                   const GraphNodeSet& cluster_outputs,
-                                   const GraphNodeSet& cluster_internals,
-                                   const std::string& compilation_key,
-                                   Graph* graph) {
+void ReplaceSubGraphWithCinnOpNode(
+    const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs,
+    const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_internals,
+    const std::string& compilation_key,
+    const std::unordered_set<std::string>& deny_var_set, Graph* graph) {
   // Add the cinn op node whose name is "kCinnLaunchOp" into graph
   AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key,
-                   graph);
+                   deny_var_set, graph);
   // Remove the cinn subgraph from graph
   RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
 }
@@ -398,9 +464,11 @@ void SearchAllSubgraphs(Graph* graph) {
     // Classify var node to inputs, outputs, and internals.
     GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
 
+    auto deny_var_set = GetDenyVarNames(cluster_set);
+
     GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
-    AnalyseClusterVariables(cluster_set, &cluster_inputs, &cluster_outputs,
-                            &cluster_internals);
+    AnalyseClusterVariables(cluster_set, deny_var_set, &cluster_inputs,
+                            &cluster_outputs, &cluster_internals);
 
     VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
     VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
@@ -417,7 +485,8 @@ void SearchAllSubgraphs(Graph* graph) {
 
     // Replace the found cluster to a new cinn op node
     ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs,
-                                  cluster_internals, compilation_key, graph);
+                                  cluster_internals, compilation_key,
+                                  deny_var_set, graph);
   }
 }
 }  // namespace