From ea65e61cb1d06e5ee2752ce4b53ae636aa2d0ce1 Mon Sep 17 00:00:00 2001 From: wenchunjiang Date: Sat, 16 May 2020 11:21:21 +0800 Subject: [PATCH] adding new constructKernelGraph to transform all subgraphs to kernel_graph from root func_graph --- .../ccsrc/session/anf_runtime_algorithm.cc | 10 + .../ccsrc/session/anf_runtime_algorithm.h | 1 + mindspore/ccsrc/session/session_basic.cc | 184 +++++++++++++++++- mindspore/ccsrc/session/session_basic.h | 4 + mindspore/ccsrc/utils/utils.h | 4 + 5 files changed, 202 insertions(+), 1 deletion(-) mode change 100755 => 100644 mindspore/ccsrc/session/session_basic.cc diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 7260bb46d..7e7f2fd40 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -909,5 +909,15 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull &node) { auto kernel_name = AnfAlgo::GetCNodeName(node); return kernel_name == kGetNextOpName; } + +FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + auto func_graph = value->cast(); + return func_graph; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index be88075f4..f9b426261 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -181,6 +181,7 @@ class AnfRuntimeAlgorithm { static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static bool IsCommunicationOp(const AnfNodePtr &node); static bool IsGetNext(const NotNull &node); + static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc old mode 100755 new mode 100644 index b3267db6c..00b3e14fc --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -451,6 +451,126 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K return new_cnode; } +CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + std::vector cnode_inputs; + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + if (IsValueNode(attr_input)) { + // create primitive of cnode:call + cnode_inputs = {std::make_shared(std::make_shared(kCallOpName))}; + // create a ValueNode as input of cnode:call + if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)); + } else { + auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph); + if (new_value_node != nullptr) { + cnode_inputs.emplace_back(new_value_node); + } + } + } else if (attr_input->isa()) { + // create primitive of cnode:call(switch) + cnode_inputs = {std::make_shared(std::make_shared(kCallOpName))}; + if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { + auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); + auto prim = GetCNodePrimitive(cnode_input); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() != kSwitchOpName) { + MS_LOG(EXCEPTION) << "CNode input[0] must be switch."; + } + cnode_inputs.emplace_back(cnode_input); + } else { + MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() + << ", but input[0] has not been created."; + } + } else { + // get primitive of old node + auto prim = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(prim); + // push attr to inputs[0] of new cnode + cnode_inputs = {std::make_shared(std::make_shared(*prim))}; + } + + for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { + auto anf = cnode->inputs()[input_idx]; + MS_EXCEPTION_IF_NULL(anf); + // anf has been created before + if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + continue; + } else if (anf->isa()) { + if (!IsValueNode(anf)) { + // if input is a common value node, + auto new_value_node = CreateNewValueNode(anf, graph); + if (new_value_node != nullptr) { + cnode_inputs.emplace_back(new_value_node); + } + } else { + // if input is a ValueNode + auto new_value_node = CreateValueNodeKernelGraph(anf, graph); + if (new_value_node != nullptr) { + cnode_inputs.emplace_back(new_value_node); + } + } + continue; + } else if (anf->isa()) { + auto new_parameter = CreateNewParameter(anf, graph); + cnode_inputs.push_back(new_parameter); + continue; + } + MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; + } + TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); + auto new_cnode = graph->NewCNode(cnode_inputs); + TraceManager::EndTrace(); + return new_cnode; +} + +ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + auto value_node = anf->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); + MS_EXCEPTION_IF_NULL(sub_func_graph); + if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) { + MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph."; + } + auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph]; + + ValueNodePtr new_value_node = std::make_shared(sub_kernel_graph); + new_value_node->set_abstract(value_node->abstract()); + // create new kernel_info of new value_node + auto kernel_info = std::make_shared(); + kernel_info->SetFeatureMapFlag(false); + new_value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get()); + + graph->FrontBackendlMapAdd(anf, new_value_node); + graph->AddValueNodeToGraph(new_value_node); + + return new_value_node; +} + +ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + if (!anf->isa()) { + MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; + } + + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + + auto new_parameter = graph->NewParameter(anf->cast()); + graph_inputs->push_back(new_parameter); + graph->FrontBackendlMapAdd(anf, new_parameter); + + return new_parameter; +} + KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { std::unordered_map other_graph_cnode; auto graph = NewKernelGraph(); @@ -494,7 +614,69 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con return graph; } -std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphPtr &) { return nullptr; } +std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) { + MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph."; + return front_backend_graph_map_[func_graph]; + } + auto node_list = TopoSort(func_graph->get_return()); + auto graph = NewKernelGraph(); + front_backend_graph_map_[func_graph] = graph; + MS_LOG(INFO) << "Create graph: " << graph->graph_id(); + + for (const auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); + if (!node->isa()) { + MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode"; + continue; + } else { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + + // recurse control ops: call, partial + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + if (IsValueNode(attr_input)) { + // recurse call subgraph + auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input); + ConstructKernelGraph(sub_func_graph); + } else if (IsValueNode(attr_input)) { + auto prim = GetCNodePrimitive(node); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == kPartialOpName) { + // recurse partial subgraph + auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex); + MS_EXCEPTION_IF_NULL(func_graph_node); + auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node); + ConstructKernelGraph(sub_func_graph); + } + } + + // create a new cnode object + auto new_cnode = CreateNewCNode(cnode, graph.get()); + MS_EXCEPTION_IF_NULL(new_cnode); + new_cnode->set_abstract(cnode->abstract()); + new_cnode->set_scope(cnode->scope()); + graph->FrontBackendlMapAdd(node, new_cnode); + + // set original return to kernel_graph + if (IsPrimitive(new_cnode->input(kAnfPrimitiveIndex), prim::kPrimReturn)) { + graph->set_return(new_cnode); + } + } + } + + MS_EXCEPTION_IF_NULL(context_); + FuncGraphManagerPtr manager = context_->manager(); + if (manager) { + manager->AddFuncGraph(graph); + graph->set_manager(manager); + } + graph->SetExecOrderByDefault(); + return graph; +} // run graph steps void SessionBasic::LoadInputData(const std::shared_ptr &kernel_graph, diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index efd27743f..885ca8da4 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -78,6 +78,7 @@ class SessionBasic { CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, std::unordered_map *other_graph_cnode); + CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); // set parameters of final graph virtual GraphId SetFinalGraphInput(const std::vector &) { return kInvalidGraphId; } @@ -111,9 +112,12 @@ class SessionBasic { // create a new kernel graph and update the graph sum KernelGraphPtr NewKernelGraph(); ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); + ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); + ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); std::unordered_map> graphs_; std::unordered_map> run_op_graphs_; + std::unordered_map front_backend_graph_map_; std::shared_ptr context_; CallBackFunc summary_callback_; static GraphId graph_sum_; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index a0eb2740d..e3169daaa 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -141,6 +141,9 @@ constexpr auto kLabelSetOpName = "LabelSet"; constexpr auto kLabelSwitchOpName = "LabelSwitch"; constexpr auto kLabelGotoOpName = "LabelGoto"; constexpr auto kBNInferGradOpName = "BNInferGrad"; +constexpr auto kCallOpName = "call"; +constexpr auto kPartialOpName = "partial"; +constexpr auto kSwitchOpName = "switch"; // attr key name constexpr auto kAttrInputNames = "input_names"; @@ -196,6 +199,7 @@ const size_t kMemAlignSize = 512; // define special index in special node constexpr auto kAnfPrimitiveIndex = 0; +constexpr auto kAnfPartialFuncGraphIndex = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kTupleGetItemInputSize = 3; -- GitLab