diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/common.py b/mindspore/_extends/parallel_compile/tbe_compiler/common.py
index 1aeba9889dee858627d932e189600a8c2797d9a6..3d55cf60a2cf907605c841d7e7f21dedd54c37b0 100644
--- a/mindspore/_extends/parallel_compile/tbe_compiler/common.py
+++ b/mindspore/_extends/parallel_compile/tbe_compiler/common.py
@@ -15,13 +15,6 @@
 """tbe common"""
 import json
 import os
-from attrdict import AttrDict
-
-class ParamType(AttrDict):
-    Required = "required"
-    Dynamic = "dynamic"
-    Optional = "optional"
-
 class TBEException(Exception):
     """tbe exception class"""
 
@@ -112,7 +105,7 @@ def get_input_output(io_info, args):
             if len(item) > 1:
                 arg.append(info)
             else:
-                if info['param_type'] == ParamType.Dynamic:
+                if info['param_type'] == 'dynamic':
                     arg.append(info)
                     args.append(arg)
                 else:
diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
index 5e15055a08956b8cadb0c60484a817ee107d8c0c..26ab826a7fde31d1423814069b13dafb35cc44ff 100644
--- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
@@ -542,7 +542,7 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) {
   GetNeedActiveStreams(graph_ptr);
 
   MS_LOG(INFO) << "after finish stream assign";
-  PrintGraphExeOrders(graph_ptr);
+  graph_ptr->PrintGraphExecuteOrder();
 
   // Get info for D Model
   generator::IRModelUtil::GetInstance().set_event_num(total_event_num());
@@ -810,26 +810,6 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<session::KernelGraph> &graph_ptr) {
   graph_ptr->set_execution_order(exe_orders);
 }
-
-void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<session::KernelGraph> &graph_ptr) {
-  MS_EXCEPTION_IF_NULL(graph_ptr);
-  auto cnode_ptr_list = graph_ptr->execution_order();
-  for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
-    CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
-    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
-    if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) {
-      auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
-      MS_LOG(INFO) << "node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id["
-                   << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
-                   << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], event_id["
-                   << GetValue<uint32_t>(primitive->GetAttr(kAttrEventId)) << "]";
-    } else {
-      MS_LOG(INFO) << "node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
-                   << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
-                   << AnfAlgo::GetStreamId(cur_cnode_ptr) << "]";
-    }
-  }
-}
 }  // namespace ascend
 }  // namespace device
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h
old mode 100755
new mode 100644
index b6f6bfd4798ac987dae60b7c01dca7efc1b77c6e..7728e61fb053e60a6dd4e1abf89f4d6135fa1ddc
--- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h
+++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h
@@ -87,7 +87,6 @@ class AscendStreamAssign {
   void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
   void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
   void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
-  void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
   void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
   uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
   void SetCommonStreamNum(uint32_t cur_stream_id);
diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc
index 2853caa732880b15f5f619341089c818bd44fdcf..949b1af2a8de69eda10363e2dbe70f4b206ed572 100644
--- a/mindspore/ccsrc/session/ascend_control_parser.cc
+++ b/mindspore/ccsrc/session/ascend_control_parser.cc
@@ -32,7 +32,6 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
 
 namespace mindspore {
 namespace session {
-
 void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
   for (auto &iter : graph_id_map) {
     auto &kg = iter.second;
@@ -356,12 +355,6 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
 std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
                                                         const NotNull<std::set<KernelGraphPtr> *> memo) {
   MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
-  auto print_vector = [&](std::vector<CNodePtr> vec) -> void {
-    MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order";
-    for (size_t i = 0; i < vec.size(); i++) {
-      MS_LOG(INFO) << "[" << i << "][" << vec[i]->DebugString() << "]";
-    }
-  };
   if (memo->find(graph) != memo->end()) {
     return {};
   }
@@ -403,7 +396,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
     }
   }
   graph->set_execution_order(execution_order);
-  print_vector(graph->execution_order());
+  graph->PrintGraphExecuteOrder();
   return execution_order;
 }
 
@@ -474,6 +467,5 @@ void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
   }
   kg->set_child_graph_order(child_graph_order);
 }
-
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h
index bb1aee76afd95c1d2fab8d3bd6263f61763abb44..037077766e5a12f86466e547f0bf9f83e70b86b2 100644
--- a/mindspore/ccsrc/session/ascend_control_parser.h
+++ b/mindspore/ccsrc/session/ascend_control_parser.h
@@ -26,7 +26,6 @@
 
 namespace mindspore {
 namespace session {
-
 class AscendControlParser {
  public:
  static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map);
diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc
index f1b15b27ab1048d076e2596854bf6a7c3c4805fd..c23763b2b2a524cd8576c117c73f7a8f2380e706 100644
--- a/mindspore/ccsrc/session/ascend_session.cc
+++ b/mindspore/ccsrc/session/ascend_session.cc
@@ -206,39 +206,40 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePtr> &cnode_list,
+static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args,
+                                      KernelGraph *child_graph) {
+  MS_EXCEPTION_IF_NULL(child_graph);
+  MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
+  if (args.empty()) {
+    return;
+  }
+  if (parameters.size() != args.size()) {
+    MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
+                      << " and args size:" << args.size() << " not equal!";
+  }
+  child_graph->SetExecOrderByDefault();
+  for (size_t i = 0; i < parameters.size(); i++) {
+    if (args[i] == parameters[i]) {
+      child_graph->SetRealInput(parameters[i], args[i]);
+      MS_LOG(INFO) << "Parameter and arg are same";
+      continue;
+    }
+    // if arg is a parameter ,then reuse this parameter
+    if (args[i]->isa<Parameter>()) {
+      MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
+                   << " reuse parameter:" << args[i]->DebugString()
+                   << " of graph:" << AnfAlgo::GetGraphId(args[i].get());
+      child_graph->ReplaceNode(parameters[i], args[i]);
+      continue;
+    }
+    child_graph->SetRealInput(parameters[i], args[i]);
+  }
+}
+
 // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
 // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
 static void UpdateRealInput(NotNull<KernelGraph *> graph) {
   auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
-  auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> &parameters,
-                                          const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
-    MS_EXCEPTION_IF_NULL(child_graph);
-    MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
-    if (args.empty()) {
-      return;
-    }
-    if (parameters.size() != args.size()) {
-      MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
-                        << " and args size:" << args.size() << " not equal!";
-    }
-    child_graph->SetExecOrderByDefault();
-    for (size_t i = 0; i < parameters.size(); i++) {
-      if (args[i] == parameters[i]) {
-        child_graph->SetRealInput(parameters[i], args[i]);
-        MS_LOG(INFO) << "Parameter and arg are same";
-        continue;
-      }
-      // if arg is a parameter ,then reuse this parameter
-      if (args[i]->isa<Parameter>()) {
-        MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
-                     << " reuse parameter:" << args[i]->DebugString()
-                     << " of graph:" << AnfAlgo::GetGraphId(args[i].get());
-        child_graph->ReplaceNode(parameters[i], args[i]);
-        continue;
-      }
-      child_graph->SetRealInput(parameters[i], args[i]);
-    }
-  };
   for (auto &call_node : call_nodes) {
     MS_EXCEPTION_IF_NULL(call_node);
     auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
@@ -247,7 +248,7 @@ static void UpdateRealInput(NotNull<KernelGraph *> graph) {
       std::vector<AnfNodePtr> real_args =
        std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
       std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
-      bind_call_arg_with_parameter(child_inputs, real_args, child_graphs[0].get());
+      BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get());
       call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
     } else if (child_graphs.size() == 2) {
       auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
@@ -264,8 +265,8 @@ static void UpdateRealInput(NotNull<KernelGraph *> graph) {
          std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
         return ret;
       };
-      bind_call_arg_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
-      bind_call_arg_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
+      BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
+      BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
     }
   }
 }
@@ -1429,10 +1430,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
   }
 }
 
-std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
-                                                             const std::vector<CNodePtr> &list) {
-  MS_EXCEPTION_IF_NULL(new_kernel_graph);
-  MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
+static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list) {
   // count the output of every anf node
   std::set<AnfNodePtr> has_output_nodes;
   for (auto &anf_node : list) {
@@ -1440,6 +1438,28 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
       (void)has_output_nodes.insert(input);
     }
   }
+
+  auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
+  std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
+  int output_idx = 0;
+  for (auto &anf_node : list) {
+    if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
+      new_kernel_graph->set_return(anf_node);
+    }
+    if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
+      MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
+      make_tuple_inputs.push_back(anf_node);
+    }
+  }
+  if (new_kernel_graph->get_return() == nullptr) {
+    new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
+  }
+}
+
+std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
+                                                             const std::vector<CNodePtr> &list) {
+  MS_EXCEPTION_IF_NULL(new_kernel_graph);
+  MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
   MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
   std::vector<AnfNodePtr> call_node_inputs;
   std::vector<AnfNodePtr> new_graph_inputs;
@@ -1479,22 +1499,9 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
   MS_EXCEPTION_IF_NULL(graph_inputs);
   graph_inputs->clear();
   std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs));
+
   MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
-  auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
-  std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
-  int output_idx = 0;
-  for (auto &anf_node : list) {
-    if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
-      new_kernel_graph->set_return(anf_node);
-    }
-    if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
-      MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
-      make_tuple_inputs.push_back(anf_node);
-    }
-  }
-  if (new_kernel_graph->get_return() == nullptr) {
-    new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
-  }
+  ConstructSplitedGraphOutput(new_kernel_graph, list);
   MS_LOG(INFO) << "end";
   return call_node_inputs;
 }
@@ -1516,6 +1523,30 @@ void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
   RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo));
 }
 
+AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
+                                                const std::vector<CNodePtr> &child_graph_list) {
+  // if child graph list only has a call ,then return the exist call
+  if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
+    return child_graph_list[0];
+  }
+  // create new child graph
+  auto child_graph = NewKernelGraph();
+  MS_EXCEPTION_IF_NULL(child_graph);
+  // create new value node to bind child graph
+  auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
+  std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
+                                            graph_value_node};
+  // set the graph id of all node of child graph
+  for (auto &child_graph_node : child_graph_list) {
+    AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
+  }
+  auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
+  std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
+  auto new_call = graph->NewCNode(new_call_input);
+  AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
+  return new_call;
+}
+
 void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
   MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
   auto apply_list = GetCNodes(TopoSort(graph->get_return()));
@@ -1523,32 +1554,10 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
   std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
-  auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
-    // if child graph list only has a call ,then return the exist call
-    if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
-      return child_graph_list[0];
-    }
-    // create new child graph
-    auto child_graph = NewKernelGraph();
-    MS_EXCEPTION_IF_NULL(child_graph);
-    // create new value node to bind child graph
-    auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
-    std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
-                                              graph_value_node};
-    // set the graph id of all node of child graph
-    for (auto &child_graph_node : child_graph_list) {
-      AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
-    }
-    auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
-    std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
-    auto new_call = graph->NewCNode(new_call_input);
-    AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
-    return new_call;
-  };
   if (child_graph_lists.size() > 1) {
     std::list<AnfNodePtr> depend_input = {};
     for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
-      auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
+      auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]);
       MS_EXCEPTION_IF_NULL(call_node);
       // if call node is the last call of true graph,no need create child graph after that
       auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
@@ -1605,6 +1614,5 @@ void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
     RecurseCompileGraph(NOT_NULL(child_graph), memo);
   }
 }
-
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h
index 529304714c451926a338bb4cd61bbaa224617a92..0c31c4c77e71199aa37f00f9e9484098ed484973 100755
--- a/mindspore/ccsrc/session/ascend_session.h
+++ b/mindspore/ccsrc/session/ascend_session.h
@@ -107,6 +107,7 @@ class AscendSession : public SessionBasic {
                                             const std::vector<CNodePtr> &list);
   void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
   void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
+  AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list);
   // merge execution order list of child graphs
   void MergeGraphExecOrder();
diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc
index 902375c1557c18e9e2a0fa599f7d136fb4b6abd4..d5c0036ed38ba73b1988f872f3e7e3a9a4dc3b4f 100644
--- a/mindspore/ccsrc/session/kernel_graph.cc
+++ b/mindspore/ccsrc/session/kernel_graph.cc
@@ -735,6 +735,26 @@ void KernelGraph::UpdateCallRealInput() {
   real_inputs_ = real_inputs_map;
 }
 
+void KernelGraph::PrintGraphExecuteOrder() const {
+  MS_LOG(INFO) << "graph:" << graph_id_ << "execution order";
+  for (size_t i = 0; i < execution_order_.size(); i++) {
+    CNodePtr cur_cnode_ptr = execution_order_[i];
+    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+    if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) {
+      auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
+      MS_LOG(INFO) << "index[" << i << "], node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id["
+                   << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
+                   << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], event_id["
+                   << GetValue<uint32_t>(primitive->GetAttr(kAttrEventId)) << "], node info["
+                   << cur_cnode_ptr->DebugString() << "]";
+    } else {
+      MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
+                   << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
+                   << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]";
+    }
+  }
+}
+
 std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
 
 KernelGraph::~KernelGraph() {
diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h
index 98a007d1a10e3024764de85dbdfd6facbb203b92..8c8ba5f8b518384a2c4cb20b4ae07db930609607 100644
--- a/mindspore/ccsrc/session/kernel_graph.h
+++ b/mindspore/ccsrc/session/kernel_graph.h
@@ -136,6 +136,7 @@ class KernelGraph : public FuncGraph {
   CNodePtr get_end_goto() { return end_goto_; }
   bool get_output_null() { return null_output_; }
   void set_output_null(bool is_output_null) { null_output_ = is_output_null; }
+  void PrintGraphExecuteOrder() const;
 
  private:
   // remove value node form graph
diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index d65a04c9d15d4754e4960ffe0f9d94f1e3f08119..b1bfefcac1aeb326fa6156a324ba57ab06124463 100644
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -563,7 +563,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
       // if input is a ValueNode
       FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
       if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
-        MS_LOG(INFO) << "FuncGraph: " << child_graph->ToString() << " has been transformed to KernelGraph.";
         is_trace_back = true;
       } else {
         (void)ConstructKernelGraph(child_graph);
@@ -587,29 +586,34 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
   }
   // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
   graph->set_output_null(is_trace_back);
+  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
+  MS_EXCEPTION_IF_NULL(context_);
+  FuncGraphManagerPtr manager = context_->manager();
+  if (manager) {
+    manager->AddFuncGraph(graph);
+    graph->set_manager(manager);
+  }
+  graph->SetExecOrderByDefault();
+  return graph;
+}
+
+void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
   auto graph_inputs = graph->MutableInputs();
   MS_EXCEPTION_IF_NULL(graph_inputs);
   graph_inputs->clear();
-  for (auto &parameter : func_graph->parameters()) {
+  for (auto &parameter : parameters) {
     MS_EXCEPTION_IF_NULL(parameter);
     auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
     if (backend_parameter == nullptr) {
      // for example "def f(x,y,z) {return x + y}", parameter z in unused
-      CreateNewParameterFromParameter(parameter, false, graph.get());
+      CreateNewParameterFromParameter(parameter, false, graph);
       MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
       continue;
     }
     MS_LOG(INFO) << "graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
     graph_inputs->push_back(backend_parameter);
   }
-  MS_EXCEPTION_IF_NULL(context_);
-  FuncGraphManagerPtr manager = context_->manager();
-  if (manager) {
-    manager->AddFuncGraph(graph);
-    graph->set_manager(manager);
-  }
-  graph->SetExecOrderByDefault();
-  return graph;
 }
 
 // run graph steps
diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h
index 7c1a3fe96622eead52eac6494e70f4043c308055..142c5b68be4cf90b940ce8262377d6c01b58cf1b 100755
--- a/mindspore/ccsrc/session/session_basic.h
+++ b/mindspore/ccsrc/session/session_basic.h
@@ -118,6 +118,7 @@ class SessionBasic {
   ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
   AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
+  void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
 
   std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
   std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;