diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
index fe2b2eb3f1ccddb1963b5bf4e8e33653cb6f2c8e..b0d9a96d47ccf89a9e4300822794a7fb507805fe 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
@@ -942,7 +942,6 @@ std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
   } else if (input1->isa() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
     auto switch_node = input1->cast();
     MS_EXCEPTION_IF_NULL(switch_node);
-    MS_LOG(INFO) << "switch : " << switch_node->DebugString();
     auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
       auto partial = switch_node->input(input_index);
       MS_EXCEPTION_IF_NULL(partial);
@@ -950,7 +949,6 @@ std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
       MS_EXCEPTION_IF_NULL(partial_cnode);
       auto graph_node = partial_cnode->input(1);
       MS_EXCEPTION_IF_NULL(graph_node);
-      MS_LOG(INFO) << graph_node->DebugString();
       auto graph_value_node = graph_node->cast();
       MS_EXCEPTION_IF_NULL(graph_value_node);
       auto graph_value = graph_value_node->value();
@@ -976,5 +974,20 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
   }
   MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
 }
+
+// Returns true if `child_graph` contains a call node whose single target kernel graph is the
+// parent graph of `child_graph`; returns false otherwise. Raises if `child_graph` is null.
+bool AnfRuntimeAlgorithm::IsWhileTrueGraph(const KernelGraphPtr &child_graph) {
+  MS_EXCEPTION_IF_NULL(child_graph);
+  auto call_nodes = child_graph->FindNodeByPrimitive(prim::kPrimCall);
+  for (const auto &call_node : call_nodes) {
+    auto graphs = GetCallNodeKernelGraph(call_node);
+    if (graphs.size() == 1 && graphs[0] == child_graph->parent_graph()) {
+      return true;
+    }
+  }
+  return false;
+}
+
 } // namespace session
 } // namespace mindspore
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h
index 10ae5282e0a4a7a8e798eabb39439a1b4b114cbf..950a8976543af95366069249a4289e490bbd6f5e 100644
--- 
a/mindspore/ccsrc/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h
@@ -185,6 +185,7 @@ class AnfRuntimeAlgorithm {
   static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
   static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node);
   static bool IsSwitchCall(const CNodePtr &call_node);
+  static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph);
 };
 } // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc
index f4f2f4bb5fd43aac68d795ee882e7fb3e6186d77..9fe9fc9f4bc5bbbf66f4a3c02c11cf779e04972c 100644
--- a/mindspore/ccsrc/session/ascend_session.cc
+++ b/mindspore/ccsrc/session/ascend_session.cc
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include <list>
 #include "operator/ops.h"
 #include "ir/meta_tensor.h"
 #include "ir/anf.h"
@@ -160,7 +161,7 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
 std::vector GetCNodes(const std::vector &anf_nodes) {
   std::vector cnodes = {};
   size_t i = 0;
-  for (auto anf : anf_nodes) {
+  for (const auto &anf : anf_nodes) {
     MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
     MS_EXCEPTION_IF_NULL(anf);
     if (anf->isa()) {
@@ -192,6 +193,8 @@ std::vector> GetChildList(const KernelGraph &cur_graph, co
   return ret;
 }
 
+// If a call has kernel inputs, it is a child graph split from ME, so these kernel inputs should be set as the real
+// inputs of the graph. For example, call input = (prim, graph, kernel1, kernel2), then real_input = [kernel1, kernel2].
 void UpdateRealInput(KernelGraph *graph) {
   auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
   auto bind_call_partial_with_parameter = [&](const std::vector &parameters,
@@ -239,6 +242,15 @@ void UpdateRealInput(KernelGraph *graph) {
     }
   }
 }
+
+void RecurseToUpdateCallRealInput(KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(INFO) << "start graph id:" << graph->graph_id();
+  graph->UpdateCallRealInput();
+  for (auto &child_graph : graph->child_graph_order()) {
+    RecurseToUpdateCallRealInput(child_graph.get());
+  }
+}
 } // namespace
 
 GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
@@ -254,7 +266,7 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) {
   MS_LOG(INFO) << "start";
   auto graph = ConstructKernelGraph(func_graph);
   // split switch
-  SplitGraph(graph);
+  SplitGraphs(graph);
   // insert goto labels and label_sets
   LinkChildGraphs(NOT_NULL(graph));
   // resource initialize
@@ -1366,8 +1378,8 @@ void AscendSession::SyncInitialTenosrToDevice() {
   }
 }
 
-KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph,
-                                               const std::vector &list) {
+KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
+                                                    const std::vector &list) {
   MS_EXCEPTION_IF_NULL(new_kernel_graph);
   MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id();
   // count the output of every anf node
@@ -1376,9 +1388,6 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
     for (auto &input : anf_node->inputs()) {
       (void)has_output_nodes.insert(input);
     }
-    if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
-      new_kernel_graph->set_return(anf_node->cast());
-    }
   }
   MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
   // create new parameter from cnode
@@ -1386,6 +1395,7 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
     auto cnode = anf_node->cast();
     for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
       auto input = cnode->inputs()[input_idx];
+      MS_EXCEPTION_IF_NULL(input);
       if (!input->isa()) {
         cnode->set_input(input_idx, input);
         continue;
@@ -1417,6 +1427,14 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
   return new_kernel_graph;
 }
 
+// Runs SplitGraph on `root_graph`, then visits the graph and its child graphs to update call real inputs.
+void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) {
+  MS_EXCEPTION_IF_NULL(root_graph);
+  SplitGraph(root_graph);
+  // replace the real input if the real input is a call
+  RecurseToUpdateCallRealInput(root_graph.get());
+}
+
 void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
   MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
   MS_EXCEPTION_IF_NULL(graph);
@@ -1426,6 +1444,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
   // get child list from current graph
   std::vector> child_graph_lists = GetChildList(*graph, apply_list);
   auto bind_new_call_to_new_graph = [&](std::vector child_graph_list) -> AnfNodePtr {
+    // if child graph list only has a call ,then return the exist call
     if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
       return child_graph_list[0];
     }
@@ -1440,22 +1459,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
     for (auto &child_graph_node : child_graph_list) {
       AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
     }
-    SplitKernelGraph(child_graph, child_graph_list);
+    ConstructSplitedGraph(child_graph, child_graph_list);
     auto new_call = graph->NewCNode(new_call_input);
     AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
     return new_call;
   };
   if (child_graph_lists.size() > 1) {
+    std::list depend_input = {};
     for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
       auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
-      if (call_index == 0) {
-        auto new_return_primitive =
-          graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimReturn->name())));
-        graph->set_return(graph->NewCNode({new_return_primitive, call_node}));
-        continue;
-      }
-      InsertDependToGraph(graph->graph_id(), call_node);
+      depend_input.push_front(call_node);
     }
+    depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name()))));
+    auto depend = graph->NewCNode(std::vector(depend_input.begin(), depend_input.end()));
+    auto 
new_return_primitive =
+      graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimReturn->name())));
+    graph->set_return(graph->NewCNode({new_return_primitive, depend}));
   }
   graph->UpdateChildGraphOrder();
   UpdateRealInput(graph.get());
diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h
index aa1050b61b8510764ee5bd37d5849f239cb9e758..d8b60cf3b3c2e4b9439ba60e6c1f50be185196ee 100755
--- a/mindspore/ccsrc/session/ascend_session.h
+++ b/mindspore/ccsrc/session/ascend_session.h
@@ -97,15 +97,16 @@ class AscendSession : public SessionBasic {
   void SetFinalGraphOutput(const VectorRef &vec_output);
 
   void SplitGraph(const KernelGraphPtr &graph);
+  // split graphs with recurse from root graph
+  void SplitGraphs(const KernelGraphPtr &root_graph);
   void LinkChildGraphs(NotNull graph);
-
   void IRFusion(const KernelGraphPtr &graph) {}
   void SelectKernelGraphKernel(const KernelGraph &graph) {}
   void ConvertPredictModel(const KernelGraphPtr graph) {}
   void HardwareOptimizeGraphs(const KernelGraphPtr graph) {}
   void RootGraphExecutorValidate(KernelGraph *graph) {}
   void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph);
-  KernelGraphPtr SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, const std::vector &list);
+  KernelGraphPtr ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, const std::vector &list);
   void ChildGraphCommunicationDecrease(std::vector> *anf_node_lists);
 
   // merge execution order list of child graphs
diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc
index c6b84d57ad6fd96d277d2a942f65abf865b73265..592108d3734b4e4adca0d8f1cb53f055963bd50f 100644
--- a/mindspore/ccsrc/session/kernel_graph.cc
+++ b/mindspore/ccsrc/session/kernel_graph.cc
@@ -39,16 +39,40 @@ void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que,
     MS_LOG(DEBUG) << "Push que:" << node->DebugString();
   }
 }
+
+// Resolves `call_node` through AnfAlgo::VisitKernelWithReturnType and returns that node unless it is a call.
+// For a call, the outputs of every called kernel graph are resolved recursively and concatenated; kernel graphs
+// for which AnfAlgo::IsWhileTrueGraph returns true are skipped.
+std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
+  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(call_node, 0);
+  MS_EXCEPTION_IF_NULL(item_with_index.first);
+  if (!AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
+    return {item_with_index.first};
+  }
+  auto call_cnode = item_with_index.first->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(call_cnode);
+  std::vector<AnfNodePtr> real_outputs;
+  auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_cnode);
+  for (const auto &child_graph : child_graphs) {
+    MS_EXCEPTION_IF_NULL(child_graph);
+    if (AnfAlgo::IsWhileTrueGraph(child_graph)) {
+      continue;
+    }
+    auto child_real_outputs = GetCallRealOutputs(child_graph->output());
+    std::copy(child_real_outputs.begin(), child_real_outputs.end(), std::back_inserter(real_outputs));
+  }
+  return real_outputs;
+}
 } // namespace
 std::vector KernelGraph::outputs() const {
-  MS_EXCEPTION_IF_NULL(output());
-  if (IsPrimitiveCNode(output(), prim::kPrimMakeTuple)) {
+  auto graph_output = output();
+  if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
     auto make_tuple = output()->cast();
     MS_EXCEPTION_IF_NULL(make_tuple);
     auto &inputs = make_tuple->inputs();
     return std::vector(inputs.begin() + 1, inputs.end());
   }
-  return std::vector();
+  return std::vector(1, graph_output);
 }
 
 void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue,
@@ -587,6 +611,9 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
 void KernelGraph::UpdateChildGraphOrder() {
   MS_LOG(INFO) << "graph id:" << graph_id_;
   auto call_nodes = FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name()));
+  for (auto &old_child_graph : child_graph_order_) {
+    old_child_graph->set_parent_graph(nullptr);
+  }
   child_graph_order_.clear();
   for (auto &call_node : call_nodes) {
     MS_EXCEPTION_IF_NULL(call_node);
@@ -640,6 +667,7 @@ std::set KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
 }
 
 void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
   MS_EXCEPTION_IF_NULL(parameter);
   MS_EXCEPTION_IF_NULL(arg);
+  MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
   if (real_inputs_.find(parameter) == real_inputs_.end()) {
@@ -649,6 +677,44 @@ void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &ar
   (void)args.insert(arg);
 }
 
+// For every parameter, replaces each real input that resolves to a call node with the nodes GetCallRealOutputs
+// returns for that call.
+void KernelGraph::UpdateCallRealInput() {
+  MS_LOG(INFO) << "Update graph id: " << graph_id_;
+  for (auto &it : real_inputs_) {
+    auto &parameter = it.first;
+    MS_EXCEPTION_IF_NULL(parameter);
+    auto &real_inputs = it.second;
+    std::set<AnfNodePtr> new_real_inputs;
+    std::set<AnfNodePtr> erase_real_inputs;
+    for (auto &real_input : real_inputs) {
+      // if real input is a call node ,find the child graph output act as the new real input
+      auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0);
+      MS_EXCEPTION_IF_NULL(item_with_index.first);
+      if (!AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
+        continue;
+      }
+      MS_LOG(INFO) << "paramter: " << parameter->DebugString()
+                   << " erase real input:" << item_with_index.first->DebugString();
+      (void)erase_real_inputs.insert(item_with_index.first);
+      auto call_node_outputs = GetCallRealOutputs(item_with_index.first);
+      for (auto &call_node_output : call_node_outputs) {
+        MS_EXCEPTION_IF_NULL(call_node_output);
+        MS_LOG(INFO) << "paramter: " << parameter->DebugString()
+                     << " insert real input:" << call_node_output->DebugString();
+        (void)new_real_inputs.insert(call_node_output);
+      }
+    }
+    // Apply the collected changes only after the scan has finished: erasing an element of `real_inputs` while a
+    // range-for is iterating over it would invalidate the loop iterator.
+    for (auto &erase_node : erase_real_inputs) {
+      (void)real_inputs.erase(erase_node);
+    }
+    for (auto &new_real_input : new_real_inputs) {
+      (void)real_inputs.insert(new_real_input);
+    }
+  }
+}
+
 std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
 } // namespace session
 } // namespace mindspore
diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h
index de55949c7bb5addccf93681b3a2c7db6d270bee3..53e15914e89a2b287b7592c7fd553e4b8ae6064e 100644
--- 
a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -127,6 +127,8 @@ class KernelGraph : public FuncGraph { void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); // used to dump ir std::string ToString() const override; + // update the real input if the node is a call + void UpdateCallRealInput(); void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } CNodePtr get_start_label() { return start_label_; } diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index b7c73469cd35ac134e3a2046ffd147f69336f811..7606952a4a51cdc1d58ebdfb44ce20c00f944d0a 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -640,16 +640,6 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP MS_EXCEPTION_IF_NULL(func_graph_node); auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node); ConstructKernelGraph(sub_func_graph); - } else if (prim->name() == kReturnOpName) { - std::vector outputs; - auto inputs = cnode->inputs(); - if (inputs.size() < 2) { - MS_LOG(EXCEPTION) << "CNode[return] must have two inputs at least, actual inputs size is " << inputs.size(); - } - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outputs)); - // add a make_tuple before return as graph output - graph->set_output(ConstructOutput(outputs, graph)); - continue; } } @@ -659,6 +649,9 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP new_cnode->set_abstract(cnode->abstract()); new_cnode->set_scope(cnode->scope()); graph->FrontBackendlMapAdd(node, new_cnode); + if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) { + graph->set_return(new_cnode); + } } }