From c93331c535f982b2b937c3d54eb16840334778f3 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 28 Oct 2021 16:50:24 +0800 Subject: [PATCH] Fix several bugs for enabling Paddle to train with CINN. (#36739) * Update the content of `test_parallel_executor_run_cinn.py`. * Fix some bugs in the topological sort and `CreateNewSubGraph`. * Update the CINN commit id used by Paddle. * Update the unit test to `add+relu`. * Update according to reviewers' suggestion. --- cmake/external/cinn.cmake | 2 +- .../fluid/framework/details/build_strategy.cc | 16 +- .../framework/paddle2cinn/CMakeLists.txt | 4 +- .../framework/paddle2cinn/build_cinn_pass.cc | 240 ++++++++---------- .../framework/paddle2cinn/build_cinn_pass.h | 2 +- .../paddle2cinn/cinn_graph_symbolization.cc | 74 +++++- .../paddle2cinn/cinn_graph_symbolization.h | 5 +- .../cinn_graph_symbolization_test.cc | 2 +- .../test_parallel_executor_run_cinn.py | 98 +++++-- 9 files changed, 269 insertions(+), 174 deletions(-) diff --git a/cmake/external/cinn.cmake b/cmake/external/cinn.cmake index ee5aea9f8b2..effc0b67ff6 100644 --- a/cmake/external/cinn.cmake +++ b/cmake/external/cinn.cmake @@ -27,7 +27,7 @@ add_definitions(-w) include(ExternalProject) set(CINN_SOURCE_DIR ${THIRD_PARTY_PATH}/CINN) # TODO(zhhsplendid): Modify git tag after we have release tag -set(CINN_GIT_TAG e422c01b7875301996a2baf67a14ba61b0e6192a) +set(CINN_GIT_TAG cb030430d76f42f7310d09608f9b22959ecbcb51) set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} -DPUBLISH_LIBS=ON -DWITH_TESTING=ON) set(CINN_BUILD_COMMAND $(MAKE) cinnapi -j) ExternalProject_Add( diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 1bb1ae0ea67..cee97820d6a 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -52,6 +52,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ResolveOptionConfliction(); AppendPrintGraphPass("graph_viz_pass", "_original_graph"); + +#ifdef PADDLE_WITH_CINN + if (FLAGS_use_cinn) { + // Note: This pass is used to enable cinn. + AppendPass("build_cinn_pass"); + AppendPrintGraphPass("graph_viz_pass", "_build_cinn_graph"); + } +#endif + AppendPassWithCheck(strategy_.enable_sequential_execution_, "sequential_execution_pass"); AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass"); @@ -74,13 +83,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Note: This pass is used to check whether the multi_device_graph is right. AppendPass("multi_devices_check_pass"); -#ifdef PADDLE_WITH_CINN - if (FLAGS_use_cinn) { - // Note: This pass is used to enable cinn. - AppendPass("build_cinn_pass"); - } -#endif - SetCollectiveContext(); } diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index 04931c7c4b3..e5dac1aa629 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -1,7 +1,7 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc) -cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector cinn_compiler) +cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce) cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn) -cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn) +cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn) cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn) cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 0664a63c2b7..fd668179616 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -26,9 +26,13 @@ limitations under the License. */ #include "cinn/frontend/op_mapper_registry.h" #include "cinn/frontend/op_mappers/use_op_mappers.h" #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/subgraph_detector.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" namespace paddle { namespace framework { @@ -40,11 +44,28 @@ using framework::ir::Node; using GraphNodeVec = std::vector; using GraphNodeSet = std::unordered_set; +namespace { +int ExtractOpRole(const GraphNodeSet& cluster) { + std::unordered_set op_roles; + std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName(); + for (auto* n : cluster) { + if (n->Op() && n->Op()->HasAttr(attr_name)) { + op_roles.insert(BOOST_GET_CONST(int, n->Op()->GetAttr(attr_name))); + } + } + if (op_roles.size() == 1U) { + return *(op_roles.begin()); + } else { + return static_cast(OpRole::kNotSpecified); + } +} + // Deal with subgraph's feed input var node: // create a new input var node and it's feed op node void AddFeedOpAndVar(const std::unordered_set& feed_vars, const GraphNodeSet& cluster, const std::unordered_map& old_op2new_op, + const std::unordered_map& old_var2new_var, Graph* graph) { for (auto* old_var : feed_vars) { // create feed op @@ -53,21 +74,19 @@ void AddFeedOpAndVar(const std::unordered_set& feed_vars, desc.SetOutput("Out", {old_var->Name()}); auto op = graph->CreateOpNode(&desc); - // create new feed var node (SSAGraph) - auto var = graph->CreateVarNode(old_var->Var()); + // get new feed var node + auto* var = old_var2new_var.at(old_var); // link feed op and feed var - op->outputs = {var}; - var->inputs = {op}; + IR_NODE_LINK_TO(op, var); // link feed var to cluster op for (auto* old_op : old_var->outputs) { if (cluster.count(old_op)) { - var->outputs.emplace_back(old_op2new_op.at(old_op)); - old_op2new_op.at(old_op)->inputs.emplace_back(var); + IR_NODE_LINK_TO(var, old_op2new_op.at(old_op)); } // Do not need relink old op or old var here, they will be - // fixed in RemoveLinkFromCluster, here we just deal with + // fixed in RemoveSubGraphFromGraph, here we just deal with // new subgraph's node. } } @@ -79,14 +98,14 @@ void AddFeedOpAndVar(const std::unordered_set& feed_vars, void AddParamVar(const std::unordered_set& param_vars, const GraphNodeSet& cluster, const std::unordered_map& old_op2new_op, + const std::unordered_map& old_var2new_var, Graph* graph) { for (auto* old_var : param_vars) { - auto var = graph->CreateVarNode(old_var->Var()); + auto* var = old_var2new_var.at(old_var); for (auto* old_op : old_var->outputs) { if (cluster.count(old_op)) { - var->outputs.emplace_back(old_op2new_op.at(old_op)); - old_op2new_op.at(old_op)->inputs.emplace_back(var); + IR_NODE_LINK_TO(var, old_op2new_op.at(old_op)); } } } @@ -97,14 +116,14 @@ void AddParamVar(const std::unordered_set& param_vars, void AddOutputVar(const std::unordered_set& output_vars, const GraphNodeSet& cluster, const std::unordered_map& old_op2new_op, + const std::unordered_map& old_var2new_var, Graph* graph) { for (auto* old_var : output_vars) { - auto var = graph->CreateVarNode(old_var->Var()); + auto* var = old_var2new_var.at(old_var); for (auto* old_op : old_var->inputs) { if (cluster.count(old_op)) { - var->inputs.emplace_back(old_op2new_op.at(old_op)); - old_op2new_op.at(old_op)->outputs.emplace_back(var); + IR_NODE_LINK_TO(old_op2new_op.at(old_op), var); } } } @@ -128,14 +147,25 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, std::unordered_map old_var2new_var; for (auto* var : cluster_internals) { - Node* sub_node; - if (var->Var() == nullptr) { - sub_node = subgraph->CreateEmptyNode(var->Name(), var->NodeType()); - } else { - sub_node = subgraph->CreateVarNode(var->Var()); - } + PADDLE_ENFORCE_NOT_NULL(var->Var(), + platform::errors::PreconditionNotMet( + "The var desc of the node in cluster_internals " + "shouldn't be null.")); + auto* sub_node = subgraph->CreateVarNode(var->Var()); old_var2new_var[var] = sub_node; } + for (auto* var : cluster_inputs) { + if (var->Var()) { + auto* sub_node = subgraph->CreateVarNode(var->Var()); + old_var2new_var[var] = sub_node; + } + } + for (auto* var : cluster_outputs) { + if (var->Var()) { + auto* sub_node = subgraph->CreateVarNode(var->Var()); + old_var2new_var[var] = sub_node; + } + } std::unordered_set need_feed_vars; std::unordered_set param_vars, output_vars; @@ -144,8 +174,10 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, // out-graph. for (auto* op : cluster) { for (auto* var : op->inputs) { - if (cluster_internals.count(var)) { - old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]); + // one output var maybe an input of the cluster + if (cluster_internals.count(var) || + (cluster_outputs.count(var) && old_var2new_var.count(var))) { + IR_NODE_LINK_TO(old_var2new_var.at(var), old_op2new_op.at(op)); } else if (cluster_inputs.count(var) && var->Var() != nullptr) { if (var->Var()->IsParameter()) { // Parameters have been preserved in scope, compared to feed var, @@ -162,7 +194,7 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, } for (auto* var : op->outputs) { if (cluster_internals.count(var)) { - old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]); + IR_NODE_LINK_TO(old_op2new_op.at(op), old_var2new_var.at(var)); } else if (cluster_outputs.count(var) && var->Var() != nullptr) { // Create new output var node to guarantee the independency of // subgraph. In other words, the subgraph has no connection with @@ -172,22 +204,12 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, } } - AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, subgraph.get()); - AddParamVar(param_vars, cluster, old_op2new_op, subgraph.get()); - AddOutputVar(output_vars, cluster, old_op2new_op, subgraph.get()); - - for (auto* var : cluster_internals) { - for (auto* op : var->inputs) { - if (cluster.count(op)) { - old_var2new_var[var]->inputs.emplace_back(old_op2new_op[op]); - } - } - for (auto* op : var->outputs) { - if (cluster.count(op)) { - old_var2new_var[var]->outputs.emplace_back(old_op2new_op[op]); - } - } - } + AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, old_var2new_var, + subgraph.get()); + AddParamVar(param_vars, cluster, old_op2new_op, old_var2new_var, + subgraph.get()); + AddOutputVar(output_vars, cluster, old_op2new_op, old_var2new_var, + subgraph.get()); return subgraph; } @@ -238,12 +260,26 @@ void AnalyseClusterVariables(const GraphNodeSet& cluster, } } -Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs, - const GraphNodeSet& cluster_outputs, - const std::string& compilation_key, Graph* graph) { - // add special cinn op - framework::OpDesc special_op_desc; - special_op_desc.SetType(kCinnLaunchOp); +void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs, Node* cinn_op_node) { + // add new link from cluster_inputs to cinn_op_node + for (auto* var_node : cluster_inputs) { + IR_NODE_LINK_TO(var_node, cinn_op_node); + } + + // add new link from cinn_op_node to cluster_outputs + for (auto* var_node : cluster_outputs) { + IR_NODE_LINK_TO(cinn_op_node, var_node); + } +} + +void AddCinnOpToGraph(const GraphNodeSet& cluster, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs, + const std::string& compilation_key, Graph* graph) { + // Add the cinn launch op + framework::OpDesc cinn_op_desc; + cinn_op_desc.SetType(kCinnLaunchOp); std::vector input_names; std::for_each(cluster_inputs.begin(), cluster_inputs.end(), [&input_names](Node* n) { @@ -251,7 +287,7 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs, input_names.emplace_back(n->Name()); } }); - special_op_desc.SetInput("X", input_names); + cinn_op_desc.SetInput("X", input_names); std::vector output_names; std::for_each(cluster_outputs.begin(), cluster_outputs.end(), [&output_names](Node* n) { @@ -259,96 +295,42 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs, output_names.emplace_back(n->Name()); } }); - special_op_desc.SetOutput("Out", output_names); - special_op_desc.SetAttr(kCompilationKey, compilation_key); - special_op_desc.Flush(); - auto* special_op_node = graph->CreateOpNode(&special_op_desc); - special_op_node->inputs.assign(cluster_inputs.begin(), cluster_inputs.end()); - special_op_node->outputs.assign(cluster_outputs.begin(), - cluster_outputs.end()); - return special_op_node; -} - -void AddLinkToSpecialOp(const GraphNodeSet& cluster_inputs, - const GraphNodeSet& cluster_outputs, - Node* special_op_node) { - // add new link from cluster_inputs to special_op_node - for (auto* var_node : cluster_inputs) { - var_node->outputs.push_back(special_op_node); - } - - // add new link from special_op_node to cluster_outputs - for (auto* var_node : cluster_outputs) { - var_node->inputs.push_back(special_op_node); - } -} - -void RemoveLinkFromCluster(const GraphNodeSet& cluster, - const GraphNodeSet& cluster_inputs, - const GraphNodeSet& cluster_outputs) { - // remove all nodes in cluster - auto get_preserved_ops = [&cluster](const GraphNodeVec& ops) { - GraphNodeVec nodes; - for (auto* op_node : ops) { - if (cluster.find(op_node) == cluster.end()) { - nodes.emplace_back(op_node); - } - } - return nodes; - }; - - // removing useless link from cluster_inputs to cluster - for (auto* var_node : cluster_inputs) { - auto preserved_ops = get_preserved_ops(var_node->outputs); - var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end()); - // According to SSA form, a var node must not be any two op's output, - // and the cluster_inputs var nodes is defined as an out-graph op's - // output, so the cluster_inputs var nodes are not any subgraph op's - // output. Do not reassign input list here. - } - - // removing useless link from cluster to cluster_outputs - for (auto* var_node : cluster_outputs) { - auto preserved_ops = get_preserved_ops(var_node->inputs); - var_node->inputs.assign(preserved_ops.begin(), preserved_ops.end()); - - // Note that cluster_outputs var node maybe some subgraph op's input, - // here we need remove them. - preserved_ops = get_preserved_ops(var_node->outputs); - var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end()); - } + cinn_op_desc.SetOutput("Out", output_names); + cinn_op_desc.SetAttr(kCompilationKey, compilation_key); + cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + ExtractOpRole(cluster)); + cinn_op_desc.Flush(); + auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc); + // Add new links from or to the the cinn launch op node + AddLinkToCinnOp(cluster_inputs, cluster_outputs, cinn_op_node); } // Removing cluster node and internals node from Graph void RemoveSubGraphFromGraph(const GraphNodeSet& cluster, const GraphNodeSet& cluster_internals, Graph* graph) { - for (auto* op_node : cluster) { - graph->RemoveNode(op_node); - } - for (auto* var_node : cluster_internals) { - graph->RemoveNode(var_node); - } + const std::unordered_set const_cluster{cluster.cbegin(), + cluster.cend()}; + const std::unordered_set const_internals{ + cluster_internals.cbegin(), cluster_internals.cend()}; + ir::GraphSafeRemoveNodes(graph, const_cluster); + ir::GraphSafeRemoveNodes(graph, const_internals); } -// Replacing Cinn subgraph to a special op node, whose op_type is +// Replacing Cinn subgraph to a cinn op node, whose op_type is // kCinnLaunchOp, and inputs ares cluster_inputs and outputs are // cluster_outputs. -// Meanwhile, move all links of cluster to the special op. -void ReplaceSubGraphWithSpecialOpNode(const GraphNodeSet& cluster, - const GraphNodeSet& cluster_inputs, - const GraphNodeSet& cluster_outputs, - const GraphNodeSet& cluster_internals, - const std::string& compilation_key, - Graph* graph) { - // First, add the special op node whose name is "kCinnLaunchOp" into graph - auto special_op_node = AddSpecialOpToGraph(cluster_inputs, cluster_outputs, - compilation_key, graph); - // Second, remove all graph's links which are from or to cluster nodes - RemoveLinkFromCluster(cluster, cluster_inputs, cluster_outputs); - // Third, add new links from or to the the special op node - AddLinkToSpecialOp(cluster_inputs, cluster_outputs, special_op_node); - // Finally, remove the cinn sub graph from graph +// Meanwhile, move all links of cluster to the cinn op. +void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs, + const GraphNodeSet& cluster_internals, + const std::string& compilation_key, + Graph* graph) { + // Add the cinn op node whose name is "kCinnLaunchOp" into graph + AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key, + graph); + // Remove the cinn subgraph from graph RemoveSubGraphFromGraph(cluster, cluster_internals, graph); } @@ -376,12 +358,12 @@ void SearchAllSubgraphs(Graph* graph) { // save it in CinnCompiler std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph( cluster_set, cluster_internals, cluster_inputs, cluster_outputs)); - // Replace the found cluster to a new special op node - ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs, - cluster_outputs, cluster_internals, - compilation_key, graph); + // Replace the found cluster to a new cinn op node + ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs, + cluster_internals, compilation_key, graph); } } +} // namespace void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); } diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 556ff228915..1c07fb314e9 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -20,7 +20,7 @@ namespace paddle { namespace framework { namespace paddle2cinn { -constexpr char kCinnLaunchOp[] = "CinnLaunchOp"; +constexpr char kCinnLaunchOp[] = "cinn_launch"; constexpr char kCompilationKey[] = "compilation_key"; // A pass named BuildCinnPass, the function of this pass is: diff --git a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc index e4e16498b84..793a9497da2 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc @@ -15,16 +15,18 @@ limitations under the License. */ #include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h" #include -#include #include +#include +#include #include -#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/paddle2cinn/transform_desc.h" #include "paddle/fluid/framework/variable.h" #include "cinn/frontend/op_mappers/use_op_mappers.h" #include "cinn/frontend/var_type_utils.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" namespace paddle { namespace framework { @@ -86,35 +88,93 @@ CinnGraphSymbolization::GetGraphInputParameterNames() const { // Transform paddle scope to cinn, note that we only preserve the graph’s // input parameter variable and ignore others. std::shared_ptr<::cinn::hlir::framework::Scope> -CinnGraphSymbolization::CreateCinnScope(const FeedInfoMap& feed_map) const { +CinnGraphSymbolization::CreateCinnScope(const FeedInfoMap& feed_map) { auto cinn_scope = ::cinn::hlir::framework::Scope::Create(); // get the graph's input parameter variable name list auto parameter_names = GetGraphInputParameterNames(); for (const auto& param_name : parameter_names) { - VLOG(4) << "add param var [" << param_name << "] info scope"; // if cannot find var in graph input, skip. // scope accepte the CINN format name, so here we need transform // paddle format name to CINN format. - auto* cinn_var = cinn_scope->Var( - ::cinn::utils::TransValidVarName(param_name)); + auto valid_name = ::cinn::utils::TransValidVarName(param_name); + auto* cinn_var = cinn_scope->Var(valid_name); auto& cinn_tensor = absl::get(*cinn_var); // here we only need preserve dtype and shape, do not need preserve data auto feed_info = feed_map.at(param_name); cinn_tensor->set_type(feed_info.type); cinn_tensor->Resize(::cinn::hlir::framework::Shape(feed_info.shape)); + VLOG(4) << "add paddle param var [" << param_name + << "] info cinn scope var[" << valid_name << "]"; + var_model_to_program_map_[param_name] = valid_name; } return cinn_scope; } +std::vector CinnGraphSymbolization::TopologicalSort() const { + std::unordered_set op_nodes; + std::for_each(graph_.Nodes().begin(), graph_.Nodes().end(), + [&op_nodes](Node* n) { + if (n->IsOp()) { + op_nodes.emplace(n); + } + }); + + std::unordered_map> adj_list; + std::unordered_map in_degrees; + for (auto* n : op_nodes) { + // the op's input is var + for (auto* in_var : n->inputs) { + // the var's input is op + for (auto* in_op : in_var->inputs) { + if (op_nodes.count(in_op)) { + ++adj_list[in_op][n]; + ++in_degrees[n]; + } + } + } + } + + // find topology entries + std::queue queue; + for (auto* n : op_nodes) { + if (!in_degrees[n]) { + queue.push(n); + } + } + + // topological sorting + std::vector sorted_ops; + while (!queue.empty()) { + auto* cur_op = queue.front(); + queue.pop(); + + VLOG(4) << "topological sort insert: " << cur_op->Name() << " " + << reinterpret_cast(cur_op) << " input " + << cur_op->inputs.size(); + sorted_ops.emplace_back(cur_op); + for (const auto& adj_pair : adj_list[cur_op]) { + in_degrees.at(adj_pair.first) -= adj_pair.second; + if (!in_degrees[adj_pair.first]) { + queue.push(adj_pair.first); + } + } + } + + PADDLE_ENFORCE_EQ(sorted_ops.size(), op_nodes.size(), + platform::errors::PreconditionNotMet( + "The sorting graph contains cycles.")); + return sorted_ops; +} + std::vector> CinnGraphSymbolization::TransformAllGraphOpToCinn() const { std::vector> cinn_op_descs; - const auto& sorted_ops = ir::TopologySortOperations(graph_); + auto sorted_ops = TopologicalSort(); for (auto* node : sorted_ops) { cinn_op_descs.emplace_back(std::make_unique()); auto& cinn_desc = cinn_op_descs.back(); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h index b6b4b24c6ee..af60493044c 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h @@ -102,6 +102,9 @@ class CinnGraphSymbolization { // transform all paddle var desc in feed list into cinn_var_descs_ FeedInfoMap GetFeedInfoMapFromInput() const; + // get the topological sort of the graph_ + std::vector TopologicalSort() const; + // transform all paddle op desc in graph into cinn op desc std::vector> TransformAllGraphOpToCinn() const; @@ -115,7 +118,7 @@ class CinnGraphSymbolization { // create cinn scope and add parameter's feed info into scope std::shared_ptr<::cinn::hlir::framework::Scope> CreateCinnScope( - const FeedInfoMap& feed_map) const; + const FeedInfoMap& feed_map); // get the graph op's input persistable var name set std::unordered_set GetGraphInputParameterNames() const; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization_test.cc index 940228314a1..be2ca2f73e1 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization_test.cc @@ -268,7 +268,7 @@ TEST_F(CinnGraphSymbolizationTest, sortgraph) { sort_names.emplace_back(desc->Type()); } ASSERT_EQ(sort_names, - std::vector({"feed", "mul", "feed", "add", "relu"})); + std::vector({"feed", "feed", "mul", "add", "relu"})); } TEST_F(CinnGraphSymbolizationTest, runop) { diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py index d26c7a1bb44..601da32cfb1 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py @@ -16,14 +16,17 @@ from __future__ import print_function import logging import numpy as np +import os import paddle +import shutil +import tempfile import unittest paddle.enable_static() logging.basicConfig( format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) -logger = logging.getLogger(__name__) +logger = logging.getLogger("paddle_with_cinn") def set_cinn_flag(val): @@ -36,34 +39,79 @@ def set_cinn_flag(val): return cinn_compiled +def reader(limit): + for i in range(limit): + yield np.ones([1, 28]).astype('float32') * (i * 3.14 / (i + 1)), \ + np.array([i + 1]).astype('int64') + + +def rand_data(img, label, loop_num=10): + feed = [] + data = reader(loop_num) + for _ in range(loop_num): + d, l = next(data) + feed.append({img: d, label: l}) + return feed + + +def build_program(main_program, startup_program): + with paddle.static.program_guard(main_program, startup_program): + img = paddle.static.data(name='img', shape=[1, 28], dtype='float32') + param = paddle.create_parameter( + name="bias", + shape=[1, 28], + dtype="float32", + attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Assign( + np.ones([1, 28]).astype(np.float32)))) + label = paddle.static.data(name="label", shape=[1], dtype='int64') + + hidden = paddle.add(img, param) + prediction = paddle.nn.functional.relu(hidden) + + loss = paddle.nn.functional.cross_entropy(input=prediction, label=label) + avg_loss = paddle.mean(loss) + adam = paddle.optimizer.Adam(learning_rate=0.001) + adam.minimize(avg_loss) + return img, label, avg_loss + + +def do_test(dot_save_dir): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + img, label, loss = build_program(main_program, startup_program) + + place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( + ) else paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(startup_program) + + build_strategy = paddle.static.BuildStrategy() + build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, "viz") + compiled_program = paddle.static.CompiledProgram( + main_program, build_strategy).with_data_parallel(loss_name=loss.name) + + iters = 1 + feed = rand_data(img.name, label.name, iters) + for step in range(iters): + loss_v = exe.run(compiled_program, + feed=feed[step], + fetch_list=[loss], + return_merged=False) + logger.info("loss value = {}".format(loss_v)) + + @unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.") class TestParallelExecutorRunCinn(unittest.TestCase): - def test_run_from_cinn(self): - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): - data = paddle.static.data( - name='X', shape=[None, 1], dtype='float32') - prediction = paddle.static.nn.fc(data, 2) - loss = paddle.mean(prediction) - adam = paddle.optimizer.Adam() - adam.minimize(loss) - - place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( - ) else paddle.CPUPlace() - exe = paddle.static.Executor(place) - exe.run(startup_program) - compiled_program = paddle.static.CompiledProgram( - main_program).with_data_parallel(loss_name=loss.name) - - batch_size = 16 - x = np.random.random(size=(batch_size, 1)).astype('float32') - fetch = exe.run(compiled_program, - feed={'X': x}, - fetch_list=[prediction.name], - return_merged=False) + def setUp(self): + set_cinn_flag(True) + self.tmpdir = tempfile.mkdtemp(prefix="dots_") + def tearDown(self): set_cinn_flag(False) + shutil.rmtree(self.tmpdir) + + def test_run_with_cinn(self): + do_test(self.tmpdir) if __name__ == '__main__': -- GitLab