diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc index 887b9a76a127928ee427360cd3994ef817631bba..c94940b7dd15a18b933b767026fdf044d9c89ae7 100644 --- a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc @@ -47,8 +47,8 @@ void BackendCommonOptimization(const std::shared_ptr &kern common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); optimizer->AddPassManager(common_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc index b96a7af8f30188182098a079cb9bfee05ab25cb8..ddb01bde9319e206362ca251844dc575f331cd45 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc @@ -27,86 +27,33 @@ namespace mindspore { namespace opt { namespace { -bool MakeValueNode(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return false; - } - - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - TypeId infer_data_type; - if (AnfAlgo::GetOutputTensorNum(value_node) == 0) { - infer_data_type = kTypeUnknown; - } else { - infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0); - } - kernel_build_info_builder->SetOutputsDeviceType(std::vector{infer_data_type}); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get()); - return true; -} - -void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node, - std::vector *plant_inputs, std::vector *dyn_input_sizes) { - MS_EXCEPTION_IF_NULL(plant_inputs); - MS_EXCEPTION_IF_NULL(dyn_input_sizes); - MS_EXCEPTION_IF_NULL(graph); - auto output_size = AnfAlgo::GetOutputTensorNum(input_node); - dyn_input_sizes->push_back(output_size); - std::vector convert_inputs; - auto kernel_graph = graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); - if (input_node->isa()) { - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node); - } else { - for (size_t index = 0; index < output_size; ++index) { - auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)}, - {AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get()); - convert_inputs.emplace_back(tuple_get_item); - } - } - (void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs)); -} - void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { MS_EXCEPTION_IF_NULL(cnode_ptr); MS_EXCEPTION_IF_NULL(graph); - auto &ori_args = cnode_ptr->inputs(); - if (ori_args.size() < 1) { - return; - } std::vector plant_inputs; std::vector dyn_input_sizes; - plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]); - for (size_t i = 1; i < ori_args.size(); ++i) { - auto input_node = ori_args[i]; - if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) { + plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) { + auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i); + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) { auto input_size = AnfAlgo::GetOutputTensorNum(input_node); dyn_input_sizes.push_back(input_size); - auto cnode = input_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto inputs = cnode->inputs(); - for (size_t j = 1; j < inputs.size(); ++j) { - MS_EXCEPTION_IF_NULL(inputs[j]); - if (IsValueNode(inputs[j])) { - auto success = MakeValueNode(inputs[j]); + auto make_tuple = input_node->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) { + auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j); + MS_EXCEPTION_IF_NULL(dyn_input_node); + if (IsValueNode(dyn_input_node)) { + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto success = kernel_graph->NewValueNode(dyn_input_node->cast()); if (!success) { - MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString(); + MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString(); } } - plant_inputs.push_back(inputs[j]); + plant_inputs.push_back(dyn_input_node); } - } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) { - ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes); } else { dyn_input_sizes.push_back(-1); plant_inputs.push_back(input_node); @@ -139,9 +86,8 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu for (auto &t : todos) { ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast()); } - } else { - ConvertMakeTupleInputToPlantInputs(func_graph, node->cast()); } + ConvertMakeTupleInputToPlantInputs(func_graph, node->cast()); return node; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc index 34ba83ef17098c18357f79ea95df79abd54c8068..68543328b166b95d75644a79a4380c076b3990f1 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc @@ -25,6 +25,38 @@ namespace mindspore { namespace opt { namespace { +CNodePtr ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node) { + MS_EXCEPTION_IF_NULL(graph); + if (!AnfAlgo::IsTupleOutput(input_node)) { + MS_LOG(EXCEPTION) << "Cannot using the function to convert a not tuple output node to maketuple!"; + } + if (input_node->isa()) { + MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << input_node->DebugString(); + } + std::vector convert_inputs = {NewValueNode(prim::kPrimMakeTuple)}; + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto splited_node_list = kernel_graph->SplitTupleOutputNodeToNodeList(input_node); + for (const auto &node : splited_node_list) { + if (AnfAlgo::IsTupleOutput(node)) { + convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, node)); + continue; + } + convert_inputs.emplace_back(node); + } + + auto make_tuple = graph->NewCNode(convert_inputs); + std::vector abstract_list; + auto make_tuple_input_size = AnfAlgo::GetInputTensorNum(make_tuple); + for (size_t index = 0; index < make_tuple_input_size; ++index) { + auto make_tuple_input = AnfAlgo::GetInputNode(make_tuple, index); + MS_EXCEPTION_IF_NULL(make_tuple_input); + abstract_list.emplace_back(make_tuple_input->abstract()); + } + make_tuple->set_abstract(std::make_shared(abstract_list)); + return make_tuple; +} + CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { MS_EXCEPTION_IF_NULL(cnode_ptr); MS_EXCEPTION_IF_NULL(graph); @@ -35,19 +67,25 @@ CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr std::vector types; std::vector> shapes; std::vector make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; - for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { - make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); - types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); - shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); + if (input_node->isa()) { + for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { + make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); + types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); + shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); + } + auto make_tuple = graph->NewCNode(make_tuple_inputs_list); + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); + convert_inputs.emplace_back(make_tuple); + continue; } - auto make_tuple = graph->NewCNode(make_tuple_inputs_list); - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); - convert_inputs.emplace_back(make_tuple); + convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, input_node)); } else { convert_inputs.push_back(input_node); } } - return graph->NewCNode(convert_inputs); + auto new_node = graph->NewCNode(convert_inputs); + new_node->set_abstract(cnode_ptr->abstract()); + return new_node; } } // namespace diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 3e462ca61841988bd952014e0912882a5516e619..a89b1579eab8592257141ab6d5c6badcac772e64 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -79,31 +79,6 @@ std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { return real_inputs; } -AnfNodePtr MakeValueNode(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return nullptr; - } - - ValueNodePtr new_value_node = std::make_shared(value_node->value()); - new_value_node->set_abstract(value_node->abstract()); - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { - types.push_back(kTypeUnknown); - } - kernel_build_info_builder->SetOutputsDeviceType(types); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - return new_value_node; -} - bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { if (left == right) { return true; @@ -121,6 +96,18 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { return false; } } // namespace +AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return nullptr; + } + + ValueNodePtr new_value_node = std::make_shared(value_node->value()); + new_value_node->set_abstract(value_node->abstract()); + this->SetKernelInfoForNode(new_value_node); + return new_value_node; +} + std::vector KernelGraph::outputs() const { auto graph_output = output(); if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { @@ -290,28 +277,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { MS_EXCEPTION_IF_NULL(cnode); cnode->set_abstract(std::make_shared()); CreateKernelInfoFromNewParameter(cnode); - - auto kernel_info = std::make_shared(); - std::vector feature_map_input_indexs; - // if the node only has the primitive(such as getNext) or the node's input has a feature map input - // then the node's output is a feature map output - for (size_t index = 1; index < inputs.size(); ++index) { - auto node = inputs[index]; - if (AnfAlgo::IsFeatureMapOutput(node)) { - feature_map_input_indexs.push_back(index); - } - } if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); } - if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { - kernel_info->SetFeatureMapFlag(true); - } - if (AnfAlgo::IsRealKernel(cnode)) { - AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); - AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); - } - cnode->set_kernel_info(kernel_info); + SetKernelInfoForNode(cnode); AnfAlgo::SetGraphId(graph_id_, cnode.get()); return cnode; } @@ -351,6 +320,50 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { } } +void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = std::make_shared(); + node->set_kernel_info(kernel_info); + if (node->isa()) { + std::vector feature_map_input_indexs; + kernel_info->SetFeatureMapFlag(false); + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { + if (AnfAlgo::IsFeatureMapInput(node, index)) { + kernel_info->SetFeatureMapFlag(true); + feature_map_input_indexs.push_back(index); + } + } + if (AnfAlgo::GetInputTensorNum(node) == 0) { + kernel_info->SetFeatureMapFlag(true); + } + if (AnfAlgo::IsRealKernel(node)) { + // if the node only has the primitive(such as getNext) or the node's input has a feature map input + // then the node's output is a feature map output + AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node); + AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node); + } + return; + } + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + std::vector types; + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + if (node->isa()) { + kernel_info->SetFeatureMapFlag(false); + types.emplace_back(kTypeUnknown); + } + if (node->isa()) { + auto parameter = node->cast(); + MS_EXCEPTION_IF_NULL(parameter); + bool is_weight = AnfAlgo ::IsParameterWeight(parameter); + kernel_info->SetFeatureMapFlag(!is_weight); + types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); + } + // set parameter initaial device data type + kernel_build_info_builder->SetOutputsDeviceType(types); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get()); +} + CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); auto new_cnode = std::make_shared(*cnode); @@ -366,75 +379,97 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { } ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { - ParameterPtr new_parameter = add_parameter(); + auto abstract = parameter == nullptr ? std::make_shared() : parameter->abstract(); + auto new_parameter = NewParameter(abstract); MS_EXCEPTION_IF_NULL(new_parameter); - // create kernel_info form new parameter - auto kernel_info = std::make_shared(); - size_t output_tensor_num = 1; - // if use default parameter = nullptr,it remarks create a new parameter from no parameter - if (parameter == nullptr) { - new_parameter->set_abstract(std::make_shared()); - kernel_info->SetFeatureMapFlag(true); - } else { - // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter - new_parameter->set_abstract(parameter->abstract()); + + // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter + if (parameter != nullptr) { new_parameter->set_name(parameter->name()); if (AnfAlgo::IsParameterWeight(parameter)) { new_parameter->set_default_param(parameter->default_param()); - kernel_info->SetFeatureMapFlag(false); - } else { - kernel_info->SetFeatureMapFlag(true); } } - new_parameter->set_kernel_info(kernel_info); - // create kernel_build_info for new parameter - auto kernel_build_info_builder = std::make_shared(); - // create init data type, - std::vector init_data_type = {}; - - TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); - init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); + // create kernel_info form new parameter + SetKernelInfoForNode(new_parameter); + AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); + return new_parameter; +} - // set the format of parameter to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector(output_tensor_num, kOpFormat_DEFAULT)); - // set parameter initaial device data type - kernel_build_info_builder->SetOutputsDeviceType(init_data_type); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get()); +ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) { + ParameterPtr new_parameter = add_parameter(); + new_parameter->set_abstract(abstract); + MS_EXCEPTION_IF_NULL(new_parameter); + // create kernel_info form new parameter + SetKernelInfoForNode(new_parameter); AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); return new_parameter; } +std::vector KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr ¶meter) { + MS_EXCEPTION_IF_NULL(parameter); + std::vector convert_nodes_list; + auto abstract = parameter->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + if (!abstract->isa()) { + MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString(); + } + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + for (size_t index = 0; index < tuple_abstract->size(); ++index) { + auto new_parameter = this->NewParameter((*tuple_abstract)[index]); + SetKernelInfoForNode(new_parameter); + convert_nodes_list.emplace_back(new_parameter); + } + auto new_inputs = std::make_shared>(); + auto old_inputs = inputs(); + for (const auto &input_node : old_inputs) { + if (input_node != parameter) { + new_inputs->emplace_back(input_node); + continue; + } + std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs)); + } + inputs_ = new_inputs; + return convert_nodes_list; +} + +std::vector KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << node->DebugString(); + } + if (node->isa()) { + return SplitTupleParameterToNodeList(node->cast()); + } + return SplitTupleValueNodeToNodeList(node->cast()); +} + std::vector KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { MS_EXCEPTION_IF_NULL(value_node); auto node_value = value_node->value(); - auto output_size = AnfAlgo::GetOutputTensorNum(value_node); std::vector convert_inputs; if (!node_value->isa()) { MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); } auto value_tuple = node_value->cast(); MS_EXCEPTION_IF_NULL(value_tuple); - if (value_tuple->size() != output_size) { - MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size() - << " is not mathced with the value node's output size" << output_size; + auto abstract = value_node->abstract(); + if (!abstract->isa()) { + MS_LOG(EXCEPTION) << "Spilted node's output abstract is not type tuple"; + } + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + if (tuple_abstract->size() != value_tuple->size()) { + MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "]is outof range " + << tuple_abstract->size(); } for (size_t index = 0; index < value_tuple->value().size(); ++index) { auto new_value_node = std::make_shared(value_tuple->value()[index]); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)}, - {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get()); + new_value_node->set_abstract((*tuple_abstract)[index]); AddValueNodeToGraph(new_value_node); - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - kernel_info->SetFeatureMapFlag(false); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown}); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + SetKernelInfoForNode(new_value_node); AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); - AddValueNodeToGraph(new_value_node); convert_inputs.emplace_back(new_value_node); } if (!RemoveValueNodeFromGraph(value_node)) { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 3ba5f333da4aa687ce488356c45f0e0c3c62cfc0..2764c71418226f5dcada5b7d4181739e7a871a32 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -54,8 +54,10 @@ class KernelGraph : public FuncGraph { void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); CNodePtr NewCNode(const CNodePtr &cnode); ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); + ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract); + ValueNodePtr NewValueNode(const ValuePtr &value); ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); - std::vector SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); + std::vector SplitTupleOutputNodeToNodeList(const AnfNodePtr &node); void set_execution_order(const std::vector &order) { execution_order_ = order; } const std::vector &execution_order() const { return execution_order_; } void SetExecOrderByDefault(); @@ -166,6 +168,10 @@ class KernelGraph : public FuncGraph { private: // remove value node form graph bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); + void SetKernelInfoForNode(const AnfNodePtr &node) const; + std::vector SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); + std::vector SplitTupleParameterToNodeList(const ParameterPtr ¶meter); + AnfNodePtr MakeValueNode(const AnfNodePtr &node); void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, std::unordered_set *visited_nodes); // update node edge list diff --git a/tests/ut/cpp/session/kernel_graph_test.cc b/tests/ut/cpp/session/kernel_graph_test.cc index f24036b4aa2b0eb2828e1897e4d78b698c1f9d5e..0961b09d62071781659af2b96be222898983e1e1 100644 --- a/tests/ut/cpp/session/kernel_graph_test.cc +++ b/tests/ut/cpp/session/kernel_graph_test.cc @@ -60,7 +60,7 @@ TEST_F(KernelGraphTest, NewParameter) { auto anf_graph = std::make_shared(); auto kernel_graph = std::make_shared(); // test nullptr as input - auto new_paramter = kernel_graph->NewParameter(nullptr); + auto new_paramter = kernel_graph->NewParameter(); EXPECT_NE(new_paramter, nullptr); EXPECT_TRUE(new_paramter->isa()); EXPECT_EQ(AnfAlgo::GetOutputFormat(new_paramter, 0), kOpFormat_DEFAULT);