diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index 0ae4f5f69663eb5d353e0e77ac36f2f407324724..9672384e70a55c445695ed66f4e491ff1cc2f1ef 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -340,7 +340,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { } } -void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) { +void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); std::vector non_communication_op; @@ -351,6 +351,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { continue; } + graph->AddFinalOutputKernel(item_with_index.first); if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { AssignCommunicationNodeMem(kStaticMem, item_with_index.first); } else { diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h index 0630b6dede58f4a5a2f0f42daf45a40414b0384e..656ef8e2e66a55beda0ff1e75e05946487e181dd 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -95,7 +95,7 @@ class KernelRuntime { #endif private: - void AssignStaticMemoryOutput(const session::KernelGraph *graph); + void AssignStaticMemoryOutput(session::KernelGraph *graph); void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); bool LaunchKernelMod(const session::KernelGraph &graph); diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc index 3b2402172b95e4bec2d1e9a04a52f5902705e297..4c1d2bf50dc40fa3fdcf2ed1cf9bcdbbf9140d32 100644 --- a/mindspore/ccsrc/ir/anf.cc +++ b/mindspore/ccsrc/ir/anf.cc @@ -25,7 +25,7 @@ #include "ir/func_graph.h" #include "ir/primitive_base.h" - +#include "utils/context/ms_context.h" #include "operator/ops.h" namespace mindspore { @@ -179,4 +179,43 @@ std::string get_id(const AnfNodePtr &node) { void reset_id() { node_ids.clear(); } } // namespace id_generator + +std::string GetCNodeTarget(const AnfNodePtr &node) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->device_target(); + if (!node->isa()) { + return default_target; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto attr_input = cnode->input(0); + if (attr_input == nullptr) { + return default_target; + } + auto value_node = attr_input->cast(); + if (value_node == nullptr) { + return default_target; + } + auto value = value_node->value(); + if (value == nullptr) { + return default_target; + } + if (!value->isa()) { + return default_target; + } + auto primitive = value->cast(); + auto att_target = primitive->GetAttr("primitive_target"); + if (att_target != nullptr) { + if (!att_target->isa()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } + auto target = GetValue(att_target); + if (kTargetSet.find(target) == kTargetSet.end()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } + return target; + } + return default_target; +} } // namespace mindspore diff --git a/mindspore/ccsrc/ir/anf.h b/mindspore/ccsrc/ir/anf.h index 95a018af066b2f619954bcc8c45c74d1fcd67928..8a44627885116a25810bd8cd7fcfbcd7a090ca3e 100644 --- a/mindspore/ccsrc/ir/anf.h +++ b/mindspore/ccsrc/ir/anf.h @@ -448,7 +448,7 @@ void reset_id(); } // namespace id_generator using TaggedNodeMap = std::unordered_map; using TaggedGraph = std::pair; - +std::string GetCNodeTarget(const AnfNodePtr &node); } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_ANF_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc index 953f464431c9817e2081669672bfad6f2e067911..3f77c68f861a8f97a552e4c8ef370a9f98578325 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc @@ -46,6 +46,11 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { return nullptr; } + AnfNodePtr front_node; + auto kernel_graph = func_graph->cast>(); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { + front_node = kernel_graph->GetFrontNodeByInternalOutput(node); + } AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); MS_LOG(DEBUG) << "====process op: " << node->DebugString(); AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); @@ -56,7 +61,12 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An return new_node; } } - return InsertTransOpForOutput(func_graph, new_node, kernel_select_); + auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_); + if (kernel_graph != nullptr && front_node != nullptr) { + auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node); + kernel_graph->ReplaceInternalOutput(old_node, final_node); + } + return final_node; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index f669f89b6644529a53213dfe22469e425ae9fbf5..01b56910296f29fde65fc7ea3b9b73b4934ba642 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -987,15 +987,6 @@ void AscendSession::SetFinalGraphOutput(const BaseRef &output) { } } -KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) { - auto it = graphs_.find(graph_id); - if (it == graphs_.end()) { - MS_LOG(WARNING) << "Can't find graph " << graph_id; - return nullptr; - } - return it->second; -} - void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) { MS_LOG(INFO) << "Start!"; MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]"; diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 904b011077d4d665432ce535134316d4672bf6bb..4774015457626829c497667c8bf7253c0d499f21 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -128,8 +128,6 @@ class AscendSession : public SessionBasic { void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node); // insert depend to graph, used to attch control nodes to graph void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node); - // Get graph by graph id ,if not exist return null ptr - KernelGraphPtr GetGraph(GraphId graph_id); // set child graph parameter if front arg is a anf void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx); // set child graph parameter if front arg is a tensor diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 306e3351e3a14efb700fffd1725bf159518ef8d8..7b53afac2aebb239fbf1ef88ac85c28814ab460b 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -329,6 +329,9 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { FrontBackendlMapUpdate(cnode, new_cnode); } AnfAlgo::SetGraphId(graph_id_, cnode.get()); + if (IsInternalOutput(cnode)) { + ReplaceInternalOutput(cnode, new_cnode); + } return new_cnode; } @@ -872,6 +875,76 @@ void KernelGraph::PrintGraphExecuteOrder() const { } } +void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) { + if (front_node == nullptr || node == nullptr) { + MS_LOG(INFO) << "Front node or node is nullptr"; + return; + } + MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString(); + front_to_internal_outputs_map_[front_node] = node; + internal_outputs_to_front_map_[node] = front_node; +} + +void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) { + if (new_node == nullptr || node == nullptr) { + MS_LOG(INFO) << "New node or node is nullptr"; + return; + } + if (node == new_node) { + MS_LOG(INFO) << "New node and node is the same"; + return; + } + auto iter = internal_outputs_to_front_map_.find(node); + if (iter == internal_outputs_to_front_map_.end()) { + MS_LOG(INFO) << "Node is not internal output"; + return; + } + MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString(); + internal_outputs_to_front_map_[new_node] = iter->second; + front_to_internal_outputs_map_[iter->second] = new_node; + internal_outputs_to_front_map_.erase(iter); +} + +AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const { + auto iter = front_to_internal_outputs_map_.find(front_node); + if (iter != front_to_internal_outputs_map_.end()) { + return iter->second; + } + return nullptr; +} + +bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const { + if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) { + return true; + } + return false; +} + +AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const { + auto iter = internal_outputs_to_front_map_.find(node); + if (iter != internal_outputs_to_front_map_.end()) { + return iter->second; + } + return nullptr; +} + +void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) { + if (node == nullptr) { + return; + } + (void)final_output_kernels_.insert(node); +} + +bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { + if (node == nullptr) { + return false; + } + if (final_output_kernels_.find(node) != final_output_kernels_.end()) { + return true; + } + return false; +} + std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index c7a826e5fe9c26b794134c59ca72b86127c99f7a..d6a67f3f02f1a27f1d37ec17b40fc4daba799f37 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -144,6 +144,13 @@ class KernelGraph : public FuncGraph { void PrintGraphExecuteOrder() const; const std::map> &summary_nodes() const { return summary_nodes_; } void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; } + void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node); + void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node); + AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; + bool IsInternalOutput(const AnfNodePtr &node) const; + AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const; + void AddFinalOutputKernel(const AnfNodePtr &node); + bool IsFinalOutputKernel(const AnfNodePtr &node) const; private: // remove value node form graph @@ -202,6 +209,9 @@ class KernelGraph : public FuncGraph { CNodePtr start_label_; CNodePtr end_goto_; bool null_output_; + std::unordered_map front_to_internal_outputs_map_; + std::unordered_map internal_outputs_to_front_map_; + std::set final_output_kernels_; }; } // namespace session using KernelGraphPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index 8382d6af9b64357a0a5582c6a099285bb5443bc5..3f893b60906d858c5713b414faea3b735315a6aa 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -95,6 +95,13 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne TypeId type_id = kNumberTypeFloat32; type_id = AnfAlgo::GetOutputInferDataType(node, output_index); std::vector temp_shape; + if (graph.IsInternalOutput(node)) { + temp_shape.emplace_back(1); + tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + tensor->set_device_address(address); + tensor->set_dirty(false); + return tensor; + } (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); // if in paynative mode,data only copyed to host when user want to print data @@ -172,48 +179,6 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { return new_value_node; } -std::vector CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(graph); - std::vector parameters; - std::vector pre_graph_out = {node}; - // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive - if (!AnfAlgo::IsRealKernel(node)) { - pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); - } - auto valid_inputs = graph->MutableValidInputs(); - MS_EXCEPTION_IF_NULL(valid_inputs); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { - auto parameter = graph->NewParameter(); - MS_EXCEPTION_IF_NULL(parameter); - parameter->set_abstract(abstract); - auto new_parameter = graph->NewParameter(parameter); - parameters.push_back(new_parameter); - valid_inputs->push_back(valid_input); - graph_inputs->push_back(new_parameter); - }; - for (const auto &out_node : pre_graph_out) { - MS_EXCEPTION_IF_NULL(out_node); - auto abstract = out_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; - for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { - create_parameter((*tuple_abstract)[output_idx]); - } - continue; - } - // create single parameter if is a abstract real kernel - create_parameter(out_node->abstract()); - } - return parameters; -} - size_t LoadCtrlInputTensor(const std::shared_ptr &graph, std::vector *inputs) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Load kInputCtrlTensors"; @@ -323,6 +288,103 @@ bool ExistSummaryNode(const KernelGraph *graph) { } // namespace GraphId SessionBasic::graph_sum_ = 0; + +KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) { + auto it = graphs_.find(graph_id); + if (it == graphs_.end()) { + MS_LOG(WARNING) << "Can't find graph " << graph_id; + return nullptr; + } + return it->second; +} + +void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) { + auto graph_id = GetGraphIdByNode(out_node); + if (graph_id == kInvalidGraphId) { + return; + } + auto node_graph = GetGraph(graph_id); + if (node_graph == nullptr) { + return; + } + MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString(); + auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node); + if (ref_node == nullptr) { + MS_LOG(INFO) << "No corresponding internal output for output node"; + return; + } + auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0); + auto ref_real_node = real_kernel.first; + auto ref_real_node_index = real_kernel.second; + if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node) && + node_graph->IsFinalOutputKernel(ref_real_node)) { + auto kernel_info = ref_real_node->kernel_info(); + if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { + MS_LOG(INFO) << "No kernel info"; + return; + } + auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index); + if (address == nullptr) { + MS_LOG(INFO) << "No kernel address"; + return; + } + auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index); + auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); + parameter->set_kernel_info(std::make_shared()); + auto d_kernel_info = parameter->kernel_info(); + MS_EXCEPTION_IF_NULL(d_kernel_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsDeviceType({type}); + builder.SetOutputsFormat({format}); + d_kernel_info->set_select_kernel_build_info(builder.Build()); + AnfAlgo::SetOutputAddr(address, 0, parameter.get()); + } +} + +std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, + KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(graph); + std::vector parameters; + std::vector pre_graph_out = {node}; + // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive + if (!AnfAlgo::IsRealKernel(node)) { + pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); + } + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { + auto parameter = graph->NewParameter(); + MS_EXCEPTION_IF_NULL(parameter); + parameter->set_abstract(abstract); + auto new_parameter = graph->NewParameter(parameter); + parameters.push_back(new_parameter); + valid_inputs->push_back(valid_input); + graph_inputs->push_back(new_parameter); + }; + for (const auto &out_node : pre_graph_out) { + MS_EXCEPTION_IF_NULL(out_node); + auto abstract = out_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + // create multiple parameters if is a tuple output real kernel + if (abstract->isa() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; + for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { + create_parameter((*tuple_abstract)[output_idx]); + } + continue; + } + // create single parameter if is a abstract real kernel + create_parameter(out_node->abstract()); + InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]); + } + return parameters; +} + ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(anf); @@ -877,6 +939,29 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { auto backend_anf = graph->GetBackendAnfByFrontAnf(out); if (backend_anf != nullptr) { + auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); + auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); + MS_EXCEPTION_IF_NULL(out); + auto out_func_graph = out->func_graph(); + MS_EXCEPTION_IF_NULL(out_func_graph); + auto out_func_graph_manager = out_func_graph->manager(); + if (out_func_graph_manager == nullptr) { + return backend_anf; + } + auto node_users = out_func_graph_manager->node_users(); + auto users = node_users[out]; + bool internal_output = true; + std::string kernel_target = GetCNodeTarget(front_real_kernel.first); + for (auto user : users) { + if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { + internal_output = false; + break; + } + } + if (internal_output) { + MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); + graph->AddInternalOutput(out, backend_real_kernel.first); + } return backend_anf; } MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index b9b966d90f5d2f58264255f13b302c46caea3c9d..cf85dd02250d5733655c2ee6c2e56ead6dd193f2 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -110,6 +110,8 @@ class SessionBasic { #endif protected: + // Get graph by graph id ,if not exist return null ptr + KernelGraphPtr GetGraph(GraphId graph_id); virtual void LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const; void UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, @@ -127,11 +129,13 @@ class SessionBasic { BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); // create a new kernel graph and update the graph sum KernelGraphPtr NewKernelGraph(); + std::vector CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph); virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); void AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph); + void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); std::unordered_map> graphs_; std::unordered_map> run_op_graphs_; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 3876f6279c5e1f59eacd1ce9fdf6d01201eb1545..80d2fc9df96eea5c5416b0f5dd5f2109e82beb61 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -52,45 +52,6 @@ const std::vector &GetMsNonlinearOps() { } namespace { -std::string GetCNodeTarget(const AnfNodePtr &node) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->device_target(); - if (!node->isa()) { - return default_target; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto attr_input = cnode->input(kAnfPrimitiveIndex); - if (attr_input == nullptr) { - return default_target; - } - auto value_node = attr_input->cast(); - if (value_node == nullptr) { - return default_target; - } - auto value = value_node->value(); - if (value == nullptr) { - return default_target; - } - if (!value->isa()) { - return default_target; - } - auto primitive = value->cast(); - auto att_target = primitive->GetAttr("primitive_target"); - if (att_target != nullptr) { - if (!att_target->isa()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; - } - auto target = GetValue(att_target); - if (kTargetSet.find(target) == kTargetSet.end()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; - } - return target; - } - return default_target; -} - bool ContainMultiTarget(const std::vector &nodes) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr);