diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 5f896282dcbecd45411ed0336fd425a638b2df2f..81ad02e787a9d6d95881a8137559424f296cfe80 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -694,7 +694,7 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector & MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size(); } if (shapes.empty()) { - MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes"; + node->set_abstract(std::make_shared()); } else if (shapes.size() == 1) { // single output handle std::vector shape_int; @@ -1012,6 +1012,9 @@ std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { auto partial = switch_node->input(input_index); MS_EXCEPTION_IF_NULL(partial); + if (IsValueNode(partial)) { + return GetValueNode(partial); + } auto partial_cnode = partial->cast(); MS_EXCEPTION_IF_NULL(partial_cnode); auto graph_node = partial_cnode->input(1); diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 166d4cc97a9be67f6a1482a56682cef7c39e3ec5..56f9550b9182d44c10098b58405c8f71d6d6bf8e 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -386,8 +386,7 @@ void AscendControlParser::RecurseSwitch(NotNull kg, NotNull kg, NotNull origin_switch_inputs[kCNodeSwitchCond]}; for (size_t i = 0; i < branch_partial.size(); ++i) { // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg; - std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); // 3.2 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); @@ -444,8 +442,11 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); } -std::tuple AscendControlParser::ParsePartial(NotNull node) { +KernelGraphPtr AscendControlParser::ParsePartial(NotNull node) { if (!node.get()->isa()) { + if (IsValueNode(node)) { + return GetValueNode(node); + } MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); } // 2.1 branch kernel graph and args @@ -460,7 +461,7 @@ std::tuple AscendControlParser::ParsePartial(NotNull(partial_inputs[kCNodePartialFunc]); - return {partial_cnode, branch_kg}; + return branch_kg; } void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index 2b383d7b149d06e3747e3273e3aba3d7ccb8fb60..ba1217c38dcc0ea8d38d0f313229ae079576cc9e 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -52,7 +52,7 @@ class AscendControlParser { static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, const CNodePtr &last_label); - static std::tuple ParsePartial(NotNull node); + static KernelGraphPtr ParsePartial(NotNull node); static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, NotNull from, NotNull to); diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 01b56910296f29fde65fc7ea3b9b73b4934ba642..c0bb5d4b128dd3189a9090c671c35a9de55e2007 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -247,6 +247,9 @@ static void UpdateRealInput(NotNull graph, bool split_flag) { MS_EXCEPTION_IF_NULL(switch_cnode); auto partial = switch_cnode->input(input_index); MS_EXCEPTION_IF_NULL(partial); + if (IsValueNode(partial)) { + return {}; + } auto partial_cnode = partial->cast(); MS_EXCEPTION_IF_NULL(partial_cnode); auto ret = std::vector(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 7b53afac2aebb239fbf1ef88ac85c28814ab460b..13c5a64ae237c9849916c1117c0a79fcf89d8dc6 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -357,18 +357,16 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { } else { kernel_info->SetFeatureMapFlag(true); } - // if output is a tuple tensor,now can use for loop to handle tuple tensor - output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter); } new_parameter->set_kernel_info(kernel_info); // create kernel_build_info for new parameter auto kernel_build_info_builder = std::make_shared(); // create init data type, std::vector init_data_type = {}; - for (size_t i = 0; i < output_tensor_num; i++) { - TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i); - init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); - } + + TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); + init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); + // set the format of parameter to DEFAULT_FORMAT kernel_build_info_builder->SetOutputsFormat(std::vector(output_tensor_num, kOpFormat_DEFAULT)); // set parameter initaial device data type diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc index 2ea2453381c826d0b8279313cee5ce108d55d0d6..4c94cdde5795432b1185965c1717279aedb43993 100644 --- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc +++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc @@ -590,7 +590,8 @@ TEST_F(AnfRuntimeAlgorithmTest, SetOutputInferTypeAndShape) { std::vector none_types = {}; std::vector> none_shapes = {}; EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, nullptr), std::runtime_error); - EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get()), std::runtime_error); + AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get()); + EXPECT_EQ((*add->abstract()), abstract::AbstractNone()); // set single input std::vector single_types = {kFloat32->type_id()}; std::vector> single_shapes = {{2, 32, 224, 224}};